diff --git a/doc/arithmetic-convention.nblink b/doc/arithmetic-convention.nblink new file mode 100644 index 00000000..39928e92 --- /dev/null +++ b/doc/arithmetic-convention.nblink @@ -0,0 +1,3 @@ +{ + "path": "../examples/arithmetic-convention.ipynb" +} diff --git a/doc/index.rst b/doc/index.rst index fd7f9ed8..70b8b439 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -108,10 +108,10 @@ This package is published under MIT license. :caption: User Guide user-guide + arithmetic-convention creating-variables creating-expressions creating-constraints - coordinate-alignment sos-constraints piecewise-linear-constraints piecewise-linear-constraints-tutorial diff --git a/doc/missing-data.nblink b/doc/missing-data.nblink new file mode 100644 index 00000000..64befb24 --- /dev/null +++ b/doc/missing-data.nblink @@ -0,0 +1,3 @@ +{ + "path": "../examples/missing-data.ipynb" +} diff --git a/examples/_nan-edge-cases.ipynb b/examples/_nan-edge-cases.ipynb new file mode 100644 index 00000000..8b037a40 --- /dev/null +++ b/examples/_nan-edge-cases.ipynb @@ -0,0 +1,1192 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "intro", + "metadata": {}, + "source": "# NaN Edge Cases: Legacy vs v1\n\nDevelopment notebook investigating how NaN behaves across linopy operations under both conventions.\n\n**Core principle (v1):** NaN means \"absent term\" — not a numeric value. It enters only through structural operations (`shift`, `where`, `reindex`, `mask=`) and propagates via IEEE semantics. Absent terms don't poison valid terms at the same coordinate.\n\n1. [Sources of NaN](#sources-of-nan)\n2. [isnull detection](#isnull-detection)\n3. [Arithmetic on shifted expressions](#arithmetic-on-shifted-expressions)\n4. [Combining expressions with absent terms](#combining-expressions-with-absent-terms)\n5. [Constraints from expressions with NaN](#constraints-from-expressions-with-nan)\n6. [Reviving absent slots with fillna and fill_value](#reviving-absent-slots)\n7. 
[FILL_VALUE internals](#fill_value-internals)" + }, + { + "cell_type": "code", + "id": "imports", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:47.306077Z", + "iopub.status.busy": "2026-03-12T08:25:47.305688Z", + "iopub.status.idle": "2026-03-12T08:25:47.906314Z", + "shell.execute_reply": "2026-03-12T08:25:47.906090Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T17:35:55.585725Z", + "start_time": "2026-03-14T17:35:54.358954Z" + } + }, + "source": "import warnings\n\nimport pandas as pd\nimport xarray as xr\n\nimport linopy\nfrom linopy import Model\nfrom linopy.config import LinopyDeprecationWarning\nfrom linopy.expressions import FILL_VALUE\n\nwarnings.filterwarnings(\"ignore\", category=LinopyDeprecationWarning)", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MindOpt 2.2.0 | 2e28db43, Aug 29 2025, 14:27:12 | arm64 - macOS 26.2\n", + "Start license validation (current time : 14-MAR-2026 18:35:55 UTC+0100).\n", + "[WARN ] No license file is found.\n", + "[ERROR] No valid license was found. Please visit https://opt.aliyun.com/doc/latest/en/html/installation/license.html to apply for and set up your license.\n", + "License validation terminated. 
Time : 0.000s\n", + "\n" + ] + } + ], + "execution_count": null + }, + { + "cell_type": "code", + "id": "setup", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:47.907490Z", + "iopub.status.busy": "2026-03-12T08:25:47.907377Z", + "iopub.status.idle": "2026-03-12T08:25:47.938441Z", + "shell.execute_reply": "2026-03-12T08:25:47.938258Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T17:35:55.594006Z", + "start_time": "2026-03-14T17:35:55.590699Z" + } + }, + "source": "def make_model():\n m = Model()\n time = pd.RangeIndex(5, name=\"time\")\n x = m.add_variables(lower=0, coords=[time], name=\"x\")\n return m, x", + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "shift-header", + "metadata": {}, + "source": "---\n\n## Sources of NaN\n\n### shift\n\n`.shift()` is the primary structural source of NaN. It shifts data along a dimension, creating a gap filled with `FILL_VALUE` (`vars=-1`, `coeffs=NaN`, `const=NaN`)." + }, + { + "cell_type": "code", + "id": "shift-demo", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:47.939568Z", + "iopub.status.busy": "2026-03-12T08:25:47.939444Z", + "iopub.status.idle": "2026-03-12T08:25:47.945428Z", + "shell.execute_reply": "2026-03-12T08:25:47.945260Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T17:35:55.689874Z", + "start_time": "2026-03-14T17:35:55.603203Z" + } + }, + "source": "m, x = make_model()\nexpr = 2 * x + 10\nexpr.shift(time=1)", + "outputs": [ + { + "data": { + "text/plain": [ + "LinearExpression [time: 5]:\n", + "---------------------------\n", + "[0]: None\n", + "[1]: +2 x[0] + 10\n", + "[2]: +2 x[1] + 10\n", + "[3]: +2 x[2] + 10\n", + "[4]: +2 x[3] + 10" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": null + }, + { + "cell_type": "code", + "id": "shift-variable", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:47.946355Z", + 
"iopub.status.busy": "2026-03-12T08:25:47.946298Z", + "iopub.status.idle": "2026-03-12T08:25:47.948146Z", + "shell.execute_reply": "2026-03-12T08:25:47.947974Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T17:35:55.700844Z", + "start_time": "2026-03-14T17:35:55.694910Z" + } + }, + "source": "# Variables also support shift — labels get -1 sentinel, bounds get NaN\nx.shift(time=1)", + "outputs": [ + { + "data": { + "text/plain": [ + "Variable (time: 5) - 1 masked entries\n", + "-------------------------------------\n", + "[0]: None\n", + "[1]: x[0] ∈ [0, inf]\n", + "[2]: x[1] ∈ [0, inf]\n", + "[3]: x[2] ∈ [0, inf]\n", + "[4]: x[3] ∈ [0, inf]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "roll-header", + "metadata": {}, + "source": "### roll\n\n`.roll()` is circular — values wrap around, no NaN introduced." + }, + { + "cell_type": "code", + "id": "roll-demo", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:47.949110Z", + "iopub.status.busy": "2026-03-12T08:25:47.949039Z", + "iopub.status.idle": "2026-03-12T08:25:47.954063Z", + "shell.execute_reply": "2026-03-12T08:25:47.953891Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T17:35:55.723168Z", + "start_time": "2026-03-14T17:35:55.709576Z" + } + }, + "source": "m, x = make_model()\n(2 * x + 10).roll(time=1)", + "outputs": [ + { + "data": { + "text/plain": [ + "LinearExpression [time: 5]:\n", + "---------------------------\n", + "[0]: +2 x[4] + 10\n", + "[1]: +2 x[0] + 10\n", + "[2]: +2 x[1] + 10\n", + "[3]: +2 x[2] + 10\n", + "[4]: +2 x[3] + 10" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "where-header", + "metadata": {}, + "source": "### where\n\n`.where(cond)` masks slots where the condition is False → `vars=-1, coeffs=NaN, const=NaN`." 
+ }, + { + "cell_type": "code", + "id": "where-demo", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:47.955033Z", + "iopub.status.busy": "2026-03-12T08:25:47.954967Z", + "iopub.status.idle": "2026-03-12T08:25:47.960120Z", + "shell.execute_reply": "2026-03-12T08:25:47.959950Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T17:35:55.750248Z", + "start_time": "2026-03-14T17:35:55.738201Z" + } + }, + "source": "m, x = make_model()\nmask = xr.DataArray([True, True, False, False, True], dims=[\"time\"])\n(2 * x + 10).where(mask)", + "outputs": [ + { + "data": { + "text/plain": [ + "LinearExpression [time: 5]:\n", + "---------------------------\n", + "[0]: +2 x[0] + 10\n", + "[1]: +2 x[1] + 10\n", + "[2]: None\n", + "[3]: None\n", + "[4]: +2 x[4] + 10" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "reindex-header", + "metadata": {}, + "source": "### reindex\n\n`.reindex()` expands or shrinks coordinates. New coordinates get `FILL_VALUE`." 
+ }, + { + "cell_type": "code", + "id": "reindex-demo", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:47.961042Z", + "iopub.status.busy": "2026-03-12T08:25:47.960980Z", + "iopub.status.idle": "2026-03-12T08:25:47.967846Z", + "shell.execute_reply": "2026-03-12T08:25:47.967693Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T17:35:55.782002Z", + "start_time": "2026-03-14T17:35:55.768465Z" + } + }, + "source": "m, x = make_model()\nexpr = 2 * x + 10\n\n# Expand to a larger index — new positions [5, 6] are absent\nexpr.reindex({\"time\": pd.RangeIndex(7, name=\"time\")})", + "outputs": [ + { + "data": { + "text/plain": [ + "LinearExpression [time: 7]:\n", + "---------------------------\n", + "[0]: +2 x[0] + 10\n", + "[1]: +2 x[1] + 10\n", + "[2]: +2 x[2] + 10\n", + "[3]: +2 x[3] + 10\n", + "[4]: +2 x[4] + 10\n", + "[5]: None\n", + "[6]: None" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "isnull-header", + "metadata": {}, + "source": "---\n\n## isnull detection\n\n`isnull()` checks: `(vars == -1).all(helper_dims) & const.isnull()`\n\nBoth conditions must be true — a slot is only \"absent\" if there are no variables AND no constant. This distinguishes \"absent\" from \"valid expression with zero constant\"." 
+ }, + { + "cell_type": "code", + "id": "isnull-demo", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:47.968881Z", + "iopub.status.busy": "2026-03-12T08:25:47.968800Z", + "iopub.status.idle": "2026-03-12T08:25:47.974292Z", + "shell.execute_reply": "2026-03-12T08:25:47.974130Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T17:35:55.805454Z", + "start_time": "2026-03-14T17:35:55.785946Z" + } + }, + "source": "m, x = make_model()\nshifted = (2 * x + 10).shift(time=2)\nshifted.isnull()", + "outputs": [ + { + "data": { + "text/plain": [ + "<xarray.DataArray (time: 5)> Size: 5B\n", + "array([ True,  True, False, False, False])\n", + "Coordinates:\n", + "  * time     (time) int64 40B 0 1 2 3 4" + ], + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray (time: 5)> Size: 5B\n",
+       "array([ True,  True, False, False, False])\n",
+       "Coordinates:\n",
+       "  * time     (time) int64 40B 0 1 2 3 4
" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "arithmetic-header", + "metadata": {}, + "source": "---\n\n## Arithmetic on shifted expressions\n\nWhen you do arithmetic on an expression with absent slots (from `shift`/`where`/`reindex`):\n\n- **Addition/subtraction**: fills const with 0 (additive identity) before adding. This preserves associativity: `(a + b) + c == a + (b + c)`.\n- **Multiplication/division**: NaN propagates. No implicit fill — the \"right\" neutral element depends on context (0 kills, 1 preserves).\n\nLegacy mode fills all NaN with neutral elements for both add and mul." + }, + { + "cell_type": "code", + "id": "arithmetic-legacy", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:47.975240Z", + "iopub.status.busy": "2026-03-12T08:25:47.975181Z", + "iopub.status.idle": "2026-03-12T08:25:47.983757Z", + "shell.execute_reply": "2026-03-12T08:25:47.983582Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T17:35:55.830568Z", + "start_time": "2026-03-14T17:35:55.813713Z" + } + }, + "source": "linopy.options[\"arithmetic_convention\"] = \"legacy\"\nm, x = make_model()\nshifted = (2 * x + 10).shift(time=1)\n\n# Legacy: NaN const filled with 0, then +5 = 5. 
Slot looks alive!\nshifted + 5", + "outputs": [ + { + "data": { + "text/plain": [ + "LinearExpression [time: 5]:\n", + "---------------------------\n", + "[0]: +5\n", + "[1]: +2 x[0] + 15\n", + "[2]: +2 x[1] + 15\n", + "[3]: +2 x[2] + 15\n", + "[4]: +2 x[3] + 15" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": null + }, + { + "cell_type": "code", + "id": "arithmetic-v1", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:47.984647Z", + "iopub.status.busy": "2026-03-12T08:25:47.984591Z", + "iopub.status.idle": "2026-03-12T08:25:47.992694Z", + "shell.execute_reply": "2026-03-12T08:25:47.992528Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T17:35:55.861842Z", + "start_time": "2026-03-14T17:35:55.839520Z" + } + }, + "source": "linopy.options[\"arithmetic_convention\"] = \"v1\"\nm, x = make_model()\nshifted = (2 * x + 10).shift(time=1)\n\n# v1: addition fills const with 0 (additive identity), then adds 5\nshifted + 5", + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "j9aln31mkog", + "source": "# v1: multiplication propagates NaN — absent stays absent\nshifted * 3", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "eodco2pcrqn", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:47.993521Z", + "iopub.status.busy": "2026-03-12T08:25:47.993467Z", + "iopub.status.idle": "2026-03-12T08:25:48.002848Z", + "shell.execute_reply": "2026-03-12T08:25:48.002675Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T16:37:06.445775Z", + "start_time": "2026-03-14T16:37:06.380780Z" + } + }, + "source": "---\n\n## Combining expressions with absent terms\n\nWhen two expressions are merged (e.g., `x + y.shift(1)`), each term is concatenated along the `_term` dimension. 
The constant is summed with `skipna=True` — NaN from one operand does **not** poison the other.\n\n**Key rule: absent terms don't poison valid terms at the same coordinate.**" + }, + { + "cell_type": "code", + "id": "qfgxszizmcf", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:48.003734Z", + "iopub.status.busy": "2026-03-12T08:25:48.003675Z", + "iopub.status.idle": "2026-03-12T08:25:48.011267Z", + "shell.execute_reply": "2026-03-12T08:25:48.011094Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T17:35:55.903304Z", + "start_time": "2026-03-14T17:35:55.878230Z" + } + }, + "source": [ + "linopy.options[\"arithmetic_convention\"] = \"v1\"\n", + "m, x = make_model()\n", + "y = m.add_variables(lower=0, coords=[pd.RangeIndex(5, name=\"time\")], name=\"y\")\n", + "\n", + "# x is valid everywhere, y.shift(1) is absent at time=0\n", + "# → time=0 still has x's term, only y's term is absent\n", + "x + (1 * y).shift(time=1)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "mwjx9or4azm", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:48.012148Z", + "iopub.status.busy": "2026-03-12T08:25:48.012093Z", + "iopub.status.idle": "2026-03-12T08:25:48.020636Z", + "shell.execute_reply": "2026-03-12T08:25:48.020460Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T17:35:55.937575Z", + "start_time": "2026-03-14T17:35:55.913528Z" + } + }, + "source": "# Shifted constant is LOST at the gap:\n# (y+5).shift makes the ENTIRE expression absent at time=0 — including its constant.\n# Only the outer +5 survives. 
time=1..4 get const=10 (shifted 5 + outer 5).\nx + (1 * y + 5).shift(time=1) + 5", + "outputs": [ + { + "data": { + "text/plain": [ + "LinearExpression [time: 5]:\n", + "---------------------------\n", + "[0]: +1 x[0] + 5\n", + "[1]: +1 x[1] + 1 y[0] + 10\n", + "[2]: +1 x[2] + 1 y[1] + 10\n", + "[3]: +1 x[3] + 1 y[2] + 10\n", + "[4]: +1 x[4] + 1 y[3] + 10" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": null + }, + { + "cell_type": "code", + "id": "h9wto4skk5s", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:48.021484Z", + "iopub.status.busy": "2026-03-12T08:25:48.021426Z", + "iopub.status.idle": "2026-03-12T08:25:48.029464Z", + "shell.execute_reply": "2026-03-12T08:25:48.029305Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T17:35:55.968359Z", + "start_time": "2026-03-14T17:35:55.943359Z" + } + }, + "source": "# Both expressions shifted — time=0 is fully absent (all terms absent AND const=NaN)\nresult = (1 * x).shift(time=1) + (1 * y).shift(time=1)\nprint(\"isnull:\", result.isnull().values)\nresult", + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "j00yil0a95", + "metadata": {}, + "source": "### Summary\n\n| Expression | const at time=0 | isnull at time=0 | Why |\n|---|---|---|---|\n| `x + y.shift(1)` | 0 | False | y's term absent, x valid, const sum skips NaN |\n| `x + y.shift(1) + 5` | 5 | False | Same, then +5 on const |\n| `x + (y+5).shift(1) + 5` | 5 | False | Shifted const (5) is lost — only outer +5 survives |\n| `x.shift(1) + y.shift(1)` | NaN | True | All terms absent AND all consts NaN → fully absent |" + }, + { + "cell_type": "markdown", + "id": "key-difference", + "metadata": {}, + "source": "### Legacy vs v1: scalar arithmetic on shifted expressions\n\n| | Legacy | v1 |\n|---|---|---|\n| `shifted + 5` at absent slot | const=5 (alive) | const=5 (alive, additive identity fill) |\n| `shifted * 3` at absent slot | 
coeffs=0, const=0 (alive) | coeffs=NaN, const=NaN (absent) |\n| `shifted - 5` at absent slot | const=-5 (alive) | const=-5 (alive, additive identity fill) |\n| `shifted / 2` at absent slot | coeffs=0, const=0 (alive) | coeffs=NaN, const=NaN (absent) |\n\n**v1 rule:** addition/subtraction use 0 as additive identity to fill NaN const. Multiplication/division propagate NaN — use `.fillna(value)` or `.mul(v, fill_value=)` for explicit control." + }, + { + "cell_type": "markdown", + "id": "constraint-header", + "metadata": {}, + "source": "---\n\n## Constraints from expressions with NaN\n\nAbsent slots in expressions propagate to constraint RHS. The preferred approach is to avoid NaN entirely using `isel` + positional alignment, or to filter with `.sel()`." + }, + { + "cell_type": "code", + "id": "constraint-demo", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:48.030525Z", + "iopub.status.busy": "2026-03-12T08:25:48.030458Z", + "iopub.status.idle": "2026-03-12T08:25:48.043325Z", + "shell.execute_reply": "2026-03-12T08:25:48.043168Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T17:35:55.998673Z", + "start_time": "2026-03-14T17:35:55.979979Z" + } + }, + "source": "linopy.options[\"arithmetic_convention\"] = \"v1\"\nm, x = make_model()\n\n# Preferred: isel + override avoids NaN entirely\nx_now = 1 * x.isel(time=slice(1, None))\nx_prev = 1 * x.isel(time=slice(None, -1))\nramp = x_now.sub(x_prev, join=\"override\")\nramp", + "outputs": [ + { + "data": { + "text/plain": [ + "LinearExpression [time: 4]:\n", + "---------------------------\n", + "[1]: +1 x[1] - 1 x[0]\n", + "[2]: +1 x[2] - 1 x[1]\n", + "[3]: +1 x[3] - 1 x[2]\n", + "[4]: +1 x[4] - 1 x[3]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": null + }, + { + "cell_type": "code", + "id": "constraint-fix", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:48.044184Z", + "iopub.status.busy": 
"2026-03-12T08:25:48.044131Z", + "iopub.status.idle": "2026-03-12T08:25:48.060912Z", + "shell.execute_reply": "2026-03-12T08:25:48.060763Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T17:35:56.020722Z", + "start_time": "2026-03-14T17:35:56.009573Z" + } + }, + "source": "# Alternative: filter absent slots with .sel() after shift\nshifted = (1 * x).shift(time=1)\nvalid = ~shifted.isnull()\nshifted.sel(time=valid)", + "outputs": [ + { + "data": { + "text/plain": [ + "LinearExpression [time: 4]:\n", + "---------------------------\n", + "[1]: +1 x[0]\n", + "[2]: +1 x[1]\n", + "[3]: +1 x[2]\n", + "[4]: +1 x[3]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "sanitize-header", + "metadata": {}, + "source": "---\n\n## Reviving absent slots\n\nAddition/subtraction automatically fill const with 0 (additive identity) — this is not arbitrary, it preserves associativity.\n\nFor multiplication/division, NaN propagates. To revive absent slots before multiplying:\n\n- **`.fillna(value)`** — fill before arithmetic. Works on both Variables and Expressions. `Variable.fillna(numeric)` returns a `LinearExpression`.\n- **`.mul(value, fill_value=)`** — fill and multiply in one step." 
+ }, + { + "cell_type": "code", + "id": "sanitize-demo", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:48.061831Z", + "iopub.status.busy": "2026-03-12T08:25:48.061768Z", + "iopub.status.idle": "2026-03-12T08:25:48.069806Z", + "shell.execute_reply": "2026-03-12T08:25:48.069649Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T17:35:56.056198Z", + "start_time": "2026-03-14T17:35:56.030610Z" + } + }, + "source": "linopy.options[\"arithmetic_convention\"] = \"v1\"\nm, x = make_model()\nshifted = (1 * x).shift(time=1)\n\n# Multiplication propagates NaN — absent stays absent\nshifted * 3", + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "epev84h04pn", + "source": "# fillna(0) revives, then multiplication works\nshifted.fillna(0) * 3", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-14T17:35:56.068769Z", + "start_time": "2026-03-14T17:35:56.060454Z" + } + }, + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "o542kxv546", + "source": "# Shorthand: .mul(value, fill_value=) does both in one step\nshifted.mul(3, fill_value=0)", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-14T17:35:56.127017Z", + "start_time": "2026-03-14T17:35:56.076931Z" + } + }, + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "3vl823ewxsx", + "source": "# Variable.fillna(numeric) returns a LinearExpression\nx.shift(time=1).fillna(0)", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "bpifee3jcir", + "source": "### Outer join with fill_value\n\nWhen combining expressions with mismatched coordinates, absent terms on each side don't poison valid terms. 
The outer join preserves all coordinates.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "4psrcv8pjn8", + "source": "m = Model()\ntech_a, tech_b = [\"wind\", \"solar\"], [\"solar\", \"gas\"]\ncap_a = m.add_variables(lower=0, coords=[tech_a], name=\"cap_a\")\ncap_b = m.add_variables(lower=0, coords=[tech_b], name=\"cap_b\")\ncost_a = xr.DataArray([10, 20], coords=[(\"dim_0\", tech_a)])\ncost_b = xr.DataArray([15, 25], coords=[(\"dim_0\", tech_b)])\n\n# Outer join: each tech keeps its valid terms, absent terms are ignored at solve time\ncombined = (cap_a * cost_a).add(cap_b * cost_b, join=\"outer\")\nprint(\"isnull:\", combined.isnull().values)\ncombined", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "fillvalue-header", + "metadata": {}, + "source": "---\n\n## FILL_VALUE internals\n\n| Type | Field | FILL_VALUE | Why |\n|---|---|---|---|\n| LinearExpression | `vars` | -1 | Integer sentinel (no variable) |\n| LinearExpression | `coeffs` | NaN | Absent — not a numeric value |\n| LinearExpression | `const` | NaN | Absent — needed for `isnull()` detection |\n| Variable | `labels` | -1 | Integer sentinel (no variable) |\n| Variable | `lower` | NaN | Absent bound |\n| Variable | `upper` | NaN | Absent bound |\n\nAll float fields use NaN for absence. Integer fields use -1." 
+ }, + { + "cell_type": "code", + "id": "fillvalue-demo", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:48.070762Z", + "iopub.status.busy": "2026-03-12T08:25:48.070706Z", + "iopub.status.idle": "2026-03-12T08:25:48.077412Z", + "shell.execute_reply": "2026-03-12T08:25:48.077245Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T17:35:56.142163Z", + "start_time": "2026-03-14T17:35:56.139464Z" + } + }, + "source": "print(\"FILL_VALUE:\", FILL_VALUE)", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "FILL_VALUE: {'vars': -1, 'coeffs': nan, 'const': nan}\n" + ] + } + ], + "execution_count": null + }, + { + "cell_type": "code", + "id": "cleanup", + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-12T08:25:48.078298Z", + "iopub.status.busy": "2026-03-12T08:25:48.078237Z", + "iopub.status.idle": "2026-03-12T08:25:48.079577Z", + "shell.execute_reply": "2026-03-12T08:25:48.079408Z" + }, + "ExecuteTime": { + "end_time": "2026-03-14T17:35:56.148607Z", + "start_time": "2026-03-14T17:35:56.146368Z" + } + }, + "source": "linopy.options.reset()", + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/arithmetic-convention.ipynb b/examples/arithmetic-convention.ipynb new file mode 100644 index 00000000..42a2bc3a --- /dev/null +++ b/examples/arithmetic-convention.ipynb @@ -0,0 +1,922 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "intro", + "metadata": {}, + "source": "# Arithmetic Convention\n\nlinopy is transitioning to a stricter arithmetic convention for 
coordinate alignment. This notebook covers:\n\n1. [How to opt in](#how-to-opt-in) to the new behavior\n2. [v1 convention](#v1-convention-the-future-default) — strict coordinate matching (the future default)\n3. [Legacy convention](#legacy-convention-current-default) — the current default behavior\n4. [The `join` parameter](#the-join-parameter) — explicit control over alignment\n5. [Migration guide](#migration-guide) — updating your code\n\nFor NaN handling and masking, see [Missing Data](missing-data.ipynb)." + }, + { + "cell_type": "code", + "id": "imports", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.056995Z", + "start_time": "2026-03-11T14:44:59.298634Z" + } + }, + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import xarray as xr\n", + "\n", + "import linopy" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "opt-in-header", + "metadata": {}, + "source": [ + "## How to opt in\n", + "\n", + "linopy uses a global setting to control arithmetic behavior. The default is `\"legacy\"` (backward-compatible). To enable the new strict convention:" + ] + }, + { + "cell_type": "code", + "id": "opt-in", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.064201Z", + "start_time": "2026-03-11T14:45:00.062580Z" + } + }, + "source": [ + "linopy.options[\"arithmetic_convention\"] = \"v1\"" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "opt-in-explain", + "metadata": {}, + "source": [ + "Put this at the top of your script, before any arithmetic. 
Under `\"legacy\"`, all legacy codepaths emit a `LinopyDeprecationWarning` to help you find code that needs updating.\n", + "\n", + "To silence the warnings without migrating yet:\n", + "\n", + "```python\n", + "import warnings\n", + "warnings.filterwarnings('ignore', category=linopy.LinopyDeprecationWarning)\n", + "```\n", + "\n", + "**Rollout plan:**\n", + "- **Now**: `\"legacy\"` is the default — nothing breaks.\n", + "- **linopy v1**: `\"v1\"` becomes the default, `\"legacy\"` is removed." + ] + }, + { + "cell_type": "markdown", + "id": "v1-header", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## v1 convention (the future default)\n", + "\n", + "Two rules:\n", + "\n", + "1. **Shared dimensions must match exactly.** When two operands share a dimension, their coordinate labels must be identical. A `ValueError` is raised on mismatch.\n", + "2. **Non-shared dimensions broadcast freely.** When dimensions are not shared, operands broadcast over the missing dimensions — for both expressions and constants.\n", + "\n", + "This ensures mismatches never silently produce wrong results, while preserving all standard algebraic laws.\n", + "\n", + "Inspired by [pyoframe](https://github.com/Bravos-Power/pyoframe)." 
+ ] + }, + { + "cell_type": "code", + "id": "v1-setup", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.110784Z", + "start_time": "2026-03-11T14:45:00.067922Z" + } + }, + "source": [ + "m = linopy.Model()\n", + "\n", + "time = pd.RangeIndex(5, name=\"time\")\n", + "techs = pd.Index([\"solar\", \"wind\", \"gas\"], name=\"tech\")\n", + "scenarios = pd.Index([\"low\", \"high\"], name=\"scenario\")\n", + "\n", + "x = m.add_variables(lower=0, coords=[time], name=\"x\")\n", + "y = m.add_variables(lower=0, coords=[time], name=\"y\")\n", + "gen = m.add_variables(lower=0, coords=[time, techs], name=\"gen\")\n", + "risk = m.add_variables(lower=0, coords=[techs, scenarios], name=\"risk\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "v1-works-header", + "metadata": {}, + "source": [ + "### What works" + ] + }, + { + "cell_type": "code", + "id": "v1-same-coords", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.125467Z", + "start_time": "2026-03-11T14:45:00.114440Z" + } + }, + "source": [ + "# Same coords — just works\n", + "x + y" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "v1-matching-constant", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.137242Z", + "start_time": "2026-03-11T14:45:00.129823Z" + } + }, + "source": [ + "# Constant with matching coords\n", + "factor = xr.DataArray([2, 3, 4, 5, 6], dims=[\"time\"], coords={\"time\": time})\n", + "x * factor" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "v1-broadcast-constant", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.155873Z", + "start_time": "2026-03-11T14:45:00.145769Z" + } + }, + "source": [ + "# Constant with fewer dims — broadcasts freely\n", + "cost = xr.DataArray([1.0, 0.5, 3.0], dims=[\"tech\"], coords={\"tech\": techs})\n", + "gen * cost # cost broadcasts over time" + ], + "outputs": [], + 
"execution_count": null + }, + { + "cell_type": "code", + "id": "v1-broadcast-expr", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.178325Z", + "start_time": "2026-03-11T14:45:00.166370Z" + } + }, + "source": [ + "# Expression + Expression with non-shared dims — broadcasts freely\n", + "gen + risk # (time, tech) + (tech, scenario) → (time, tech, scenario)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "v1-scalar", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.190629Z", + "start_time": "2026-03-11T14:45:00.184831Z" + } + }, + "source": [ + "# Scalar — always fine\n", + "x + 5" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "v1-constraint-broadcast", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.215193Z", + "start_time": "2026-03-11T14:45:00.202065Z" + } + }, + "source": [ + "# Constraints — RHS with fewer dims broadcasts naturally\n", + "capacity = xr.DataArray([100, 80, 50], dims=[\"tech\"], coords={\"tech\": techs})\n", + "m.add_constraints(gen <= capacity, name=\"cap\") # capacity broadcasts over time" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "v1-raises-header", + "metadata": {}, + "source": [ + "### What raises an error" + ] + }, + { + "cell_type": "code", + "id": "v1-mismatch-expr", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.228580Z", + "start_time": "2026-03-11T14:45:00.221751Z" + } + }, + "source": [ + "y_short = m.add_variables(\n", + " lower=0, coords=[pd.RangeIndex(3, name=\"time\")], name=\"y_short\"\n", + ")\n", + "\n", + "try:\n", + " x + y_short # time coords don't match: [0..4] vs [0..2]\n", + "except ValueError as e:\n", + " print(\"ValueError:\", e)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "v1-mismatch-constant", + "metadata": { + "ExecuteTime": { + "end_time": 
"2026-03-11T14:45:00.241638Z", + "start_time": "2026-03-11T14:45:00.237972Z" + } + }, + "source": [ + "partial = xr.DataArray([10, 20, 30], dims=[\"time\"], coords={\"time\": [0, 1, 2]})\n", + "\n", + "try:\n", + " x * partial # time coords [0..4] vs [0,1,2]\n", + "except ValueError as e:\n", + " print(\"ValueError:\", e)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "v1-mismatch-constraint", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.257815Z", + "start_time": "2026-03-11T14:45:00.253230Z" + } + }, + "source": [ + "try:\n", + " x <= partial # constraint RHS doesn't cover all coords\n", + "except ValueError as e:\n", + " print(\"ValueError:\", e)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "v1-nan-header", + "metadata": {}, + "source": "### NaN in arithmetic\n\nUnder v1, NaN values in arithmetic operands **raise a `ValueError`** — they are not silently replaced. See [Missing Data](missing-data.ipynb) for details on handling NaN and masking." + }, + { + "cell_type": "code", + "id": "v1-nan", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.273485Z", + "start_time": "2026-03-11T14:45:00.268607Z" + } + }, + "source": "vals = xr.DataArray([1.0, np.nan, 3.0, 4.0, 5.0], dims=[\"time\"], coords={\"time\": time})\n\ntry:\n x + vals\nexcept ValueError as e:\n print(\"ValueError:\", e)", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ValueError: Constant contains NaN values. 
Use .fillna() to handle missing values before arithmetic operations.\n" + ] + } + ], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "v1-escape-header", + "metadata": {}, + "source": "### Escape hatches for coordinate mismatches\n\nWhen coordinates don't match, you have several options:" + }, + { + "cell_type": "markdown", + "id": "v1-sel", + "metadata": {}, + "source": [ + "**Option 1: `.sel()` — subset before operating**\n", + "\n", + "The cleanest way. Explicitly select matching coordinates:" + ] + }, + { + "cell_type": "code", + "id": "v1-sel-example", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.293520Z", + "start_time": "2026-03-11T14:45:00.285174Z" + } + }, + "source": [ + "x.sel(time=[0, 1, 2]) + y_short" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "v1-join", + "metadata": {}, + "source": [ + "**Option 2: Named methods with `join=`**\n", + "\n", + "All arithmetic operations have named-method equivalents (`.add()`, `.sub()`, `.mul()`, `.div()`, `.le()`, `.ge()`, `.eq()`) that accept a `join` parameter:" + ] + }, + { + "cell_type": "code", + "id": "v1-join-example", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.315957Z", + "start_time": "2026-03-11T14:45:00.307311Z" + } + }, + "source": [ + "x.add(y_short, join=\"inner\") # intersection: time [0, 1, 2]" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "v1-join-outer", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.325706Z", + "start_time": "2026-03-11T14:45:00.319864Z" + } + }, + "source": [ + "x.mul(partial, join=\"left\") # keep x's coords, fill missing with 0" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "v1-assign-coords", + "metadata": {}, + "source": [ + "**Option 3: `.assign_coords()` — positional alignment**\n", + "\n", + "When two operands have the same shape but different 
labels, relabel one to match the other:" + ] + }, + { + "cell_type": "code", + "id": "v1-assign-coords-example", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.353860Z", + "start_time": "2026-03-11T14:45:00.344281Z" + } + }, + "source": [ + "z = m.add_variables(lower=0, coords=[pd.RangeIndex(5, 10, name=\"time\")], name=\"z\")\n", + "\n", + "# z has time=[5..9], x has time=[0..4] — same shape, different labels\n", + "x + z.assign_coords(time=x.coords[\"time\"])" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "v1-align", + "metadata": {}, + "source": [ + "**Option 4: `linopy.align()` — multi-operand pre-alignment**" + ] + }, + { + "cell_type": "code", + "id": "v1-align-example", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.371208Z", + "start_time": "2026-03-11T14:45:00.361638Z" + } + }, + "source": [ + "x_aligned, y_short_aligned = linopy.align(x, y_short, join=\"outer\")\n", + "x_aligned + y_short_aligned" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "v1-algebraic", + "metadata": {}, + "source": [ + "### Algebraic properties\n", + "\n", + "All standard algebraic laws hold under v1. You can freely refactor expressions without worrying about dimension ordering.\n", + "\n", + "| Property | Example |\n", + "|---|---|\n", + "| **Commutativity of +** | `x + y == y + x` |\n", + "| **Commutativity of ×** | `x * c == c * x` |\n", + "| **Associativity of +** | `(x + y) + z == x + (y + z)` |\n", + "| **Scalar distributivity** | `s * (x + y) == s*x + s*y` |\n", + "| **Constant distributivity** | `c[B] * (x[A] + g[A,B]) == c[B]*x[A] + c[B]*g[A,B]` |\n", + "| **Additive identity** | `x + 0 == x` |\n", + "| **Multiplicative identity** | `x * 1 == x` |\n", + "| **Double negation** | `-(-x) == x` |\n", + "| **Zero** | `x * 0 == 0` |\n", + "\n", + "**Caveat:** These guarantees only hold for operations involving at least one linopy object. 
Operations between plain constants (`DataArray + DataArray`) use their library's own rules. To enforce strict matching for xarray operations too:\n", + "\n", + "```python\n", + "xr.set_options(arithmetic_join=\"exact\")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "legacy-header", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Legacy convention (current default)\n", + "\n", + "The legacy convention is the current default (`linopy.options[\"arithmetic_convention\"] = \"legacy\"`). It uses heuristics to handle coordinate mismatches silently. This section describes its behavior for users who haven't migrated yet.\n", + "\n", + "Under legacy, all arithmetic operations emit a `LinopyDeprecationWarning`." + ] + }, + { + "cell_type": "code", + "id": "legacy-switch", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.376208Z", + "start_time": "2026-03-11T14:45:00.374668Z" + } + }, + "source": [ + "import warnings\n", + "\n", + "linopy.options[\"arithmetic_convention\"] = \"legacy\"\n", + "warnings.filterwarnings(\"ignore\", category=linopy.LinopyDeprecationWarning)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "legacy-setup", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.390854Z", + "start_time": "2026-03-11T14:45:00.386569Z" + } + }, + "source": [ + "m2 = linopy.Model()\n", + "time = pd.RangeIndex(5, name=\"time\")\n", + "x2 = m2.add_variables(lower=0, coords=[time], name=\"x\")\n", + "y2_short = m2.add_variables(\n", + " lower=0, coords=[pd.RangeIndex(3, name=\"time\")], name=\"y_short\"\n", + ")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "legacy-size-aware", + "metadata": {}, + "source": [ + "### Size-aware alignment\n", + "\n", + "When two operands share a dimension:\n", + "- **Same size**: positional alignment (labels ignored, left operand's labels kept)\n", + "- **Different size**: left-join (reindex to the 
left operand's coordinates, fill with zeros)" + ] + }, + { + "cell_type": "code", + "id": "legacy-subset", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.400749Z", + "start_time": "2026-03-11T14:45:00.393413Z" + } + }, + "source": [ + "# Different size — left join, fill missing with 0\n", + "x2 + y2_short # y_short drops out at time 3, 4" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "legacy-same-size", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.413292Z", + "start_time": "2026-03-11T14:45:00.404468Z" + } + }, + "source": [ + "# Same size — positional alignment (labels ignored!)\n", + "z2 = m2.add_variables(lower=0, coords=[pd.RangeIndex(5, 10, name=\"time\")], name=\"z\")\n", + "x2 + z2 # x has time=[0..4], z has time=[5..9], but same size → positional match" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "legacy-nan", + "metadata": {}, + "source": [ + "### NaN filling\n", + "\n", + "NaN values in constants are silently replaced with operation-specific neutral elements:\n", + "- Addition/subtraction: NaN → 0\n", + "- Multiplication: NaN → 0 (zeroes out the variable)\n", + "- Division: NaN → 1 (no scaling)" + ] + }, + { + "cell_type": "code", + "id": "legacy-nan-fill", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.421451Z", + "start_time": "2026-03-11T14:45:00.416080Z" + } + }, + "source": [ + "vals = xr.DataArray([1.0, np.nan, 3.0, 4.0, 5.0], dims=[\"time\"], coords={\"time\": time})\n", + "result = x2 + vals\n", + "print(\"const:\", result.const.values) # NaN replaced with 0" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "legacy-constraint-rhs", + "metadata": {}, + "source": [ + "### Constraint RHS\n", + "\n", + "In constraints, the RHS is reindexed to the expression's coordinates. 
Missing positions become NaN, which tells linopy to skip those constraints:" + ] + }, + { + "cell_type": "code", + "id": "legacy-constraint", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.435315Z", + "start_time": "2026-03-11T14:45:00.429597Z" + } + }, + "source": [ + "rhs = xr.DataArray([10, 20, 30], dims=[\"time\"], coords={\"time\": [0, 1, 2]})\n", + "con = x2 <= rhs # constraint only at time 0, 1, 2; NaN at time 3, 4\n", + "con" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "legacy-restore-v1", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.447130Z", + "start_time": "2026-03-11T14:45:00.445694Z" + } + }, + "source": [ + "# Switch back to v1 for the rest of the notebook\n", + "linopy.options[\"arithmetic_convention\"] = \"v1\"\n", + "warnings.resetwarnings()" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "join-header", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## The `join` parameter\n", + "\n", + "Both conventions support explicit `join=` on named methods. 
This overrides the default behavior and works identically in both modes.\n", + "\n", + "| `join` | Coordinates kept | Fill behavior |\n", + "|--------|-----------------|---------------|\n", + "| `\"exact\"` | Must match | `ValueError` if different |\n", + "| `\"inner\"` | Intersection | No fill needed |\n", + "| `\"outer\"` | Union | Fill with neutral element |\n", + "| `\"left\"` | Left operand's | Fill missing right |\n", + "| `\"right\"` | Right operand's | Fill missing left |\n", + "| `\"override\"` | Left operand's (positional) | Positional alignment |" + ] + }, + { + "cell_type": "code", + "id": "join-setup", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.457694Z", + "start_time": "2026-03-11T14:45:00.453096Z" + } + }, + "source": [ + "m3 = linopy.Model()\n", + "\n", + "i_a = pd.Index([0, 1, 2], name=\"i\")\n", + "i_b = pd.Index([1, 2, 3], name=\"i\")\n", + "\n", + "a = m3.add_variables(coords=[i_a], name=\"a\")\n", + "b = m3.add_variables(coords=[i_b], name=\"b\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "join-inner", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.473280Z", + "start_time": "2026-03-11T14:45:00.464589Z" + } + }, + "source": [ + "# Inner join — intersection (i=1, 2)\n", + "a.add(b, join=\"inner\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "join-outer", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.488676Z", + "start_time": "2026-03-11T14:45:00.478816Z" + } + }, + "source": [ + "# Outer join — union (i=0, 1, 2, 3)\n", + "a.add(b, join=\"outer\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "join-left", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.500599Z", + "start_time": "2026-03-11T14:45:00.491930Z" + } + }, + "source": [ + "# Left join — keep a's coords (i=0, 1, 2)\n", + "a.add(b, join=\"left\")" + ], + 
"outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "join-right", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.511586Z", + "start_time": "2026-03-11T14:45:00.503595Z" + } + }, + "source": [ + "# Right join — keep b's coords (i=1, 2, 3)\n", + "a.add(b, join=\"right\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "join-override", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.522774Z", + "start_time": "2026-03-11T14:45:00.515038Z" + } + }, + "source": [ + "# Override — positional (0↔1, 1↔2, 2↔3), uses a's labels\n", + "a.add(b, join=\"override\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "join-mul", + "metadata": {}, + "source": [ + "The same `join` parameter works on `.mul()`, `.div()`, `.le()`, `.ge()`, `.eq()`:" + ] + }, + { + "cell_type": "code", + "id": "join-mul-example", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.535578Z", + "start_time": "2026-03-11T14:45:00.528426Z" + } + }, + "source": [ + "const = xr.DataArray([2, 3, 4], dims=[\"i\"], coords={\"i\": [1, 2, 3]})\n", + "\n", + "# Multiply, keeping only shared coords\n", + "a.mul(const, join=\"inner\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "join-constraint", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.543751Z", + "start_time": "2026-03-11T14:45:00.538431Z" + } + }, + "source": [ + "# Constraint with left join — only a's coords, NaN at missing RHS positions\n", + "rhs = xr.DataArray([10, 20], dims=[\"i\"], coords={\"i\": [0, 1]})\n", + "a.le(rhs, join=\"left\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "migration-header", + "metadata": {}, + "source": "---\n\n## Migration guide\n\nTo migrate from legacy to v1:\n\n### Step 1: Enable v1 and run your 
code\n\n```python\nlinopy.options[\"arithmetic_convention\"] = \"v1\"\n```\n\nAny code that relied on legacy alignment will now raise `ValueError` with a helpful message suggesting which `join=` to use.\n\n### Step 2: Fix coordinate mismatches\n\nCommon patterns:\n\n| Legacy code (silent) | v1 equivalent (explicit) |\n|---|---|\n| `x + subset_constant` | `x.add(subset_constant, join=\"left\")` |\n| `x + y` (same size, different labels) | `x + y.assign_coords(time=x.coords[\"time\"])` |\n| `x <= partial_rhs` | `x.le(partial_rhs, join=\"left\")` |\n| `expr + expr` (mismatched coords) | `expr.add(other, join=\"outer\")` or `.sel()` first |\n\n### Step 3: Handle NaN\n\nUnder legacy, NaN in operands was silently replaced. Under v1, it raises `ValueError`. See [Missing Data](missing-data.ipynb) for the full migration guide.\n\n### Step 4: Pandas index names\n\nUnder v1, pandas objects must have **named indices** to align properly with linopy variables:\n\n```python\n# Will fail — unnamed index becomes \"dim_0\"\ncost = pd.Series([10, 20], index=[\"wind\", \"solar\"])\n\n# Works — explicit dimension name\ncost = pd.Series([10, 20], index=pd.Index([\"wind\", \"solar\"], name=\"tech\"))\n```" + }, + { + "cell_type": "markdown", + "id": "practical-header", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Practical example\n", + "\n", + "A generation dispatch model demonstrating both matching coords and explicit joins." 
+ ] + }, + { + "cell_type": "code", + "id": "practical-setup", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.643343Z", + "start_time": "2026-03-11T14:45:00.638601Z" + } + }, + "source": [ + "m4 = linopy.Model()\n", + "\n", + "hours = pd.RangeIndex(24, name=\"hour\")\n", + "techs = pd.Index([\"solar\", \"wind\", \"gas\"], name=\"tech\")\n", + "\n", + "gen = m4.add_variables(lower=0, coords=[hours, techs], name=\"gen\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "practical-capacity", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.661328Z", + "start_time": "2026-03-11T14:45:00.650715Z" + } + }, + "source": [ + "# Capacity limits — constant broadcasts over hours\n", + "capacity = xr.DataArray([100, 80, 50], dims=[\"tech\"], coords={\"tech\": techs})\n", + "m4.add_constraints(gen <= capacity, name=\"capacity_limit\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "practical-solar", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.680515Z", + "start_time": "2026-03-11T14:45:00.669453Z" + } + }, + "source": [ + "# Solar availability — full 24h profile, matching coords\n", + "solar_avail = np.zeros(24)\n", + "solar_avail[6:19] = 100 * np.sin(np.linspace(0, np.pi, 13))\n", + "solar_availability = xr.DataArray(solar_avail, dims=[\"hour\"], coords={\"hour\": hours})\n", + "\n", + "solar_gen = gen.sel(tech=\"solar\")\n", + "m4.add_constraints(solar_gen <= solar_availability, name=\"solar_avail\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "practical-peak", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:45:00.695847Z", + "start_time": "2026-03-11T14:45:00.684022Z" + } + }, + "source": [ + "# Peak demand — only applies to hours 8-20, use join=\"inner\"\n", + "peak_hours = pd.RangeIndex(8, 21, name=\"hour\")\n", + "peak_demand = xr.DataArray(\n", + " 
np.full(len(peak_hours), 120.0), dims=[\"hour\"], coords={\"hour\": peak_hours}\n", + ")\n", + "\n", + "total_gen = gen.sum(\"tech\")\n", + "m4.add_constraints(total_gen.ge(peak_demand, join=\"inner\"), name=\"peak_demand\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "summary", + "metadata": {}, + "source": "---\n\n## Summary\n\n| | v1 (future default) | Legacy (current default) |\n|---|---|---|\n| **Mismatched coords** | `ValueError` | Silent left-join / override |\n| **Same-size different labels** | `ValueError` | Positional alignment |\n| **NaN in operands** | `ValueError` ([details](missing-data.ipynb)) | Filled with neutral element |\n| **Explicit join** | `.add(x, join=...)` | `.add(x, join=...)` |\n| **Setting** | `options[\"arithmetic_convention\"] = \"v1\"` | `options[\"arithmetic_convention\"] = \"legacy\"` |" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/coordinate-alignment.ipynb b/examples/coordinate-alignment.ipynb deleted file mode 100644 index 1547bd9d..00000000 --- a/examples/coordinate-alignment.ipynb +++ /dev/null @@ -1,488 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Coordinate Alignment\n", - "\n", - "Since linopy builds on xarray, coordinate alignment matters when combining variables or expressions that live on different coordinates. By default, linopy aligns operands automatically and fills missing entries with sensible defaults. This guide shows how alignment works and how to control it with the ``join`` parameter." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import pandas as pd\n", - "import xarray as xr\n", - "\n", - "import linopy" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Default Alignment Behavior\n", - "\n", - "When two operands share a dimension but have different coordinates, linopy keeps the **larger** (superset) coordinate range and fills missing positions with zeros (for addition) or zero coefficients (for multiplication)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "m = linopy.Model()\n", - "\n", - "time = pd.RangeIndex(5, name=\"time\")\n", - "x = m.add_variables(lower=0, coords=[time], name=\"x\")\n", - "\n", - "subset_time = pd.RangeIndex(3, name=\"time\")\n", - "y = m.add_variables(lower=0, coords=[subset_time], name=\"y\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Adding ``x`` (5 time steps) and ``y`` (3 time steps) gives an expression over all 5 time steps. Where ``y`` has no entry (time 3, 4), the coefficient is zero — i.e. ``y`` simply drops out of the sum at those positions." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "x + y" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The same applies when multiplying by a constant that covers only a subset of coordinates. 
Missing positions get a coefficient of zero:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "factor = xr.DataArray([2, 3, 4], dims=[\"time\"], coords={\"time\": [0, 1, 2]})\n", - "x * factor" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Adding a constant subset also fills missing coordinates with zero:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "x + factor" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Constraints with Subset RHS\n", - "\n", - "For constraints, missing right-hand-side values are filled with ``NaN``, which tells linopy to **skip** the constraint at those positions:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "rhs = xr.DataArray([10, 20, 30], dims=[\"time\"], coords={\"time\": [0, 1, 2]})\n", - "con = x <= rhs\n", - "con" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The constraint only applies at time 0, 1, 2. At time 3 and 4 the RHS is ``NaN``, so no constraint is created." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "### Same-Shape Operands: Positional Alignment\n\nWhen two operands have the **same shape** on a shared dimension, linopy uses **positional alignment** by default — coordinate labels are ignored and the left operand's labels are kept. 
This is a performance optimization but can be surprising:" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "offset_const = xr.DataArray(\n", - " [10, 20, 30, 40, 50], dims=[\"time\"], coords={\"time\": [5, 6, 7, 8, 9]}\n", - ")\n", - "x + offset_const" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "Even though ``offset_const`` has coordinates ``[5, 6, 7, 8, 9]`` and ``x`` has ``[0, 1, 2, 3, 4]``, the result uses ``x``'s labels. The values are aligned by **position**, not by label. The same applies when adding two variables or expressions of identical shape:" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "z = m.add_variables(lower=0, coords=[pd.RangeIndex(5, 10, name=\"time\")], name=\"z\")\n", - "x + z" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "``x`` (time 0–4) and ``z`` (time 5–9) share no coordinate labels, yet the result has 5 entries under ``x``'s coordinates — because they have the same shape, positions are matched directly.\n\nTo force **label-based** alignment, pass an explicit ``join``:" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "x.add(z, join=\"outer\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "With ``join=\"outer\"``, the result spans all 10 time steps (union of 0–4 and 5–9), filling missing positions with zeros. This is the correct label-based alignment. The same-shape positional shortcut is equivalent to ``join=\"override\"`` — see below." - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## The ``join`` Parameter\n", - "\n", - "For explicit control over alignment, use the ``.add()``, ``.sub()``, ``.mul()``, and ``.div()`` methods with a ``join`` parameter. 
The supported values follow xarray conventions:\n", - "\n", - "- ``\"inner\"`` — intersection of coordinates\n", - "- ``\"outer\"`` — union of coordinates (with fill)\n", - "- ``\"left\"`` — keep left operand's coordinates\n", - "- ``\"right\"`` — keep right operand's coordinates\n", - "- ``\"override\"`` — positional alignment, ignore coordinate labels\n", - "- ``\"exact\"`` — coordinates must match exactly (raises on mismatch)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "m2 = linopy.Model()\n", - "\n", - "i_a = pd.Index([0, 1, 2], name=\"i\")\n", - "i_b = pd.Index([1, 2, 3], name=\"i\")\n", - "\n", - "a = m2.add_variables(coords=[i_a], name=\"a\")\n", - "b = m2.add_variables(coords=[i_b], name=\"b\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Inner join** — only shared coordinates (i=1, 2):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "a.add(b, join=\"inner\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Outer join** — union of coordinates (i=0, 1, 2, 3):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "a.add(b, join=\"outer\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Left join** — keep left operand's coordinates (i=0, 1, 2):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "a.add(b, join=\"left\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Right join** — keep right operand's coordinates (i=1, 2, 3):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "a.add(b, join=\"right\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "**Override** — positional alignment, ignore 
coordinate labels. The result uses the left operand's coordinates. Here ``a`` has i=[0, 1, 2] and ``b`` has i=[1, 2, 3], so positions are matched as 0↔1, 1↔2, 2↔3:" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "a.add(b, join=\"override\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Multiplication with ``join``\n", - "\n", - "The same ``join`` parameter works on ``.mul()`` and ``.div()``. When multiplying by a constant that covers a subset, ``join=\"inner\"`` restricts the result to shared coordinates only, while ``join=\"left\"`` fills missing values with zero:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "const = xr.DataArray([2, 3, 4], dims=[\"i\"], coords={\"i\": [1, 2, 3]})\n", - "\n", - "a.mul(const, join=\"inner\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "a.mul(const, join=\"left\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Alignment in Constraints\n", - "\n", - "The ``.le()``, ``.ge()``, and ``.eq()`` methods create constraints with explicit coordinate alignment. They accept the same ``join`` parameter:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "rhs = xr.DataArray([10, 20], dims=[\"i\"], coords={\"i\": [0, 1]})\n", - "\n", - "a.le(rhs, join=\"inner\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "With ``join=\"inner\"``, the constraint only exists at the intersection (i=0, 1). 
Compare with ``join=\"left\"``:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "a.le(rhs, join=\"left\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "With ``join=\"left\"``, the result covers all of ``a``'s coordinates (i=0, 1, 2). At i=2, where the RHS has no value, the RHS becomes ``NaN`` and the constraint is masked out.\n", - "\n", - "The same methods work on expressions:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "expr = 2 * a + 1\n", - "expr.eq(rhs, join=\"inner\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "## Practical Example\n\nConsider a generation dispatch model where solar availability follows a daily profile and a minimum demand constraint only applies during peak hours." - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "m3 = linopy.Model()\n", - "\n", - "hours = pd.RangeIndex(24, name=\"hour\")\n", - "techs = pd.Index([\"solar\", \"wind\", \"gas\"], name=\"tech\")\n", - "\n", - "gen = m3.add_variables(lower=0, coords=[hours, techs], name=\"gen\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Capacity limits apply to all hours and techs — standard broadcasting handles this:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "capacity = xr.DataArray([100, 80, 50], dims=[\"tech\"], coords={\"tech\": techs})\n", - "m3.add_constraints(gen <= capacity, name=\"capacity_limit\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "For solar, we build a full 24-hour availability profile — zero at night, sine-shaped during daylight (hours 6–18). 
Since this covers all hours, standard alignment works directly and solar is properly constrained to zero at night:" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "solar_avail = np.zeros(24)\n", - "solar_avail[6:19] = 100 * np.sin(np.linspace(0, np.pi, 13))\n", - "solar_availability = xr.DataArray(solar_avail, dims=[\"hour\"], coords={\"hour\": hours})\n", - "\n", - "solar_gen = gen.sel(tech=\"solar\")\n", - "m3.add_constraints(solar_gen <= solar_availability, name=\"solar_avail\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "Now suppose a minimum demand of 120 MW must be met, but only during peak hours (8–20). The demand array covers a subset of hours, so we use ``join=\"inner\"`` to restrict the constraint to just those hours:" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "peak_hours = pd.RangeIndex(8, 21, name=\"hour\")\n", - "peak_demand = xr.DataArray(\n", - " np.full(len(peak_hours), 120.0), dims=[\"hour\"], coords={\"hour\": peak_hours}\n", - ")\n", - "\n", - "total_gen = gen.sum(\"tech\")\n", - "m3.add_constraints(total_gen.ge(peak_demand, join=\"inner\"), name=\"peak_demand\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "The demand constraint only applies during peak hours (8–20). Outside that range, no minimum generation is required." 
- }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Summary\n", - "\n", - "| ``join`` | Coordinates | Fill behavior |\n", - "|----------|------------|---------------|\n", - "| ``None`` (default) | Auto-detect (keeps superset) | Zeros for arithmetic, NaN for constraint RHS |\n", - "| ``\"inner\"`` | Intersection only | No fill needed |\n", - "| ``\"outer\"`` | Union | Fill with operation identity (0 for add, 0 for mul) |\n", - "| ``\"left\"`` | Left operand's | Fill right with identity |\n", - "| ``\"right\"`` | Right operand's | Fill left with identity |\n", - "| ``\"override\"`` | Left operand's (positional) | Positional alignment, ignore labels |\n", - "| ``\"exact\"`` | Must match exactly | Raises error if different |" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/examples/missing-data.ipynb b/examples/missing-data.ipynb new file mode 100644 index 00000000..e660783b --- /dev/null +++ b/examples/missing-data.ipynb @@ -0,0 +1,396 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "intro", + "metadata": {}, + "source": [ + "# Missing Data and Masking\n", + "\n", + "This notebook explains linopy's NaN convention under v1 and how to handle missing data.\n", + "\n", + "1. [The NaN convention](#the-nan-convention) — design principles\n", + "2. [What raises](#what-raises) — NaN at API boundaries\n", + "3. [Handling NaN with `.fillna()`](#handling-nan-with-fillna) — choosing the right fill value\n", + "4. [Masking constraints](#masking-constraints) — `.sel()` and `mask=`\n", + "5. 
[Masking with NaN in coefficients](#masking-with-nan-in-coefficients) — multi-dimensional patterns\n", + "6. [Legacy NaN behavior](#legacy-nan-behavior-for-comparison) — how it worked before\n", + "\n", + "For coordinate alignment rules, see [Arithmetic Convention](arithmetic-convention.ipynb)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "imports", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:52:16.879309Z", + "start_time": "2026-03-11T14:52:16.087004Z" + }, + "execution": { + "iopub.execute_input": "2026-03-12T07:17:13.611630Z", + "iopub.status.busy": "2026-03-12T07:17:13.611383Z", + "iopub.status.idle": "2026-03-12T07:17:14.222456Z", + "shell.execute_reply": "2026-03-12T07:17:14.222237Z" + } + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import xarray as xr\n", + "\n", + "import linopy\n", + "\n", + "linopy.options[\"arithmetic_convention\"] = \"v1\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "setup", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:52:16.939592Z", + "start_time": "2026-03-11T14:52:16.885073Z" + }, + "execution": { + "iopub.execute_input": "2026-03-12T07:17:14.223624Z", + "iopub.status.busy": "2026-03-12T07:17:14.223510Z", + "iopub.status.idle": "2026-03-12T07:17:14.252820Z", + "shell.execute_reply": "2026-03-12T07:17:14.252554Z" + } + }, + "outputs": [], + "source": [ + "m = linopy.Model()\n", + "time = pd.RangeIndex(5, name=\"time\")\n", + "x = m.add_variables(lower=0, coords=[time], name=\"x\")\n", + "\n", + "# Data with NaN\n", + "data = xr.DataArray([1.0, np.nan, 3.0, 4.0, 5.0], dims=[\"time\"], coords={\"time\": time})" + ] + }, + { + "cell_type": "markdown", + "id": "rqgv2f7nwpb", + "metadata": {}, + "source": "---\n\n## The NaN convention\n\nIn linopy v1, **NaN means \"absent term.\"** It is never a numeric value.\n\n### How NaN enters\n\nOnly two sources produce NaN inside linopy data structures:\n\n1. 
**`mask=` argument** at construction (`add_variables`, `add_constraints`) — you explicitly declare which slots exist.\n2. **Structural operations** that produce absent slots: `.shift()`, `.where()`, `.reindex()`, `.reindex_like()`, `.unstack()` (with missing combinations).\n\nOperations that do **not** produce NaN: `.roll()` (circular), `.sel()` / `.isel()` (subset), `.drop_sel()` (drops), `.expand_dims()` / `.broadcast_like()` (broadcast existing data).\n\n### How NaN propagates\n\nAn expression is a sum of terms. Each term has a coefficient and a variable reference; the expression as a whole carries one shared constant. NaN marks an **individual term** as absent — it does not mask the entire coordinate.\n\nWhen expressions are combined (e.g., `x*2 + y.shift(time=1)`), each term is kept independently. At time=0, `y.shift` contributes no term (NaN coeffs, vars=-1), but `x*2` is still valid. The result at time=0 is `2*x[0]` — not absent.\n\nA coordinate is only fully absent when **all** terms have vars=-1 **and** the constant is NaN. This is exactly what `isnull()` checks.\n\n### Where NaN lives\n\nNaN is burned directly into the float fields: `coeffs`, `const`, `rhs`, `lower`, `upper`. Integer fields (`labels`, `vars`) use **-1** as their equivalent sentinel. There is no separate boolean mask array.\n\n### What raises\n\nAny **user-supplied NaN at an API boundary** — in bounds, constants, factors, or RHS — raises `ValueError` immediately. Masking is always explicit via `mask=` or `.sel()`, never by passing NaN as a value.\n\n### Why this is consistent\n\n- **`lhs >= rhs` is `lhs - rhs >= 0`**, so RHS obeys the same rule as any constant — no special case.\n- **No dual role for NaN**: it cannot mean both \"absent\" and \"a number I computed with.\" Internal NaN (from `shift`, `mask=`) is always structural. User NaN is always an error.\n- **Absent terms, not absent coordinates**: combining a valid expression with a partially-absent one does not destroy the valid part. 
Only when *every* term at a coordinate is absent is the coordinate itself absent." + }, + { + "cell_type": "markdown", + "id": "v1-rule-header", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## What raises\n", + "\n", + "**NaN in any arithmetic operand raises `ValueError`.** This includes:\n", + "- Constants added/subtracted: `expr + data_with_nan`\n", + "- Factors multiplied/divided: `expr * data_with_nan`\n", + "- Constraint RHS: `expr >= data_with_nan` (because `expr >= rhs` is `expr - rhs >= 0`)\n", + "\n", + "There is no implicit fill. The library does not guess whether NaN means \"zero,\" \"exclude,\" or \"identity.\" You decide." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "v1-rule-demo", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:52:16.949756Z", + "start_time": "2026-03-11T14:52:16.942400Z" + }, + "execution": { + "iopub.execute_input": "2026-03-12T07:17:14.253991Z", + "iopub.status.busy": "2026-03-12T07:17:14.253892Z", + "iopub.status.idle": "2026-03-12T07:17:14.260195Z", + "shell.execute_reply": "2026-03-12T07:17:14.259998Z" + } + }, + "outputs": [], + "source": [ + "# All of these raise ValueError:\n", + "for op_name, op_fn in [\n", + " (\"add\", lambda: x + data),\n", + " (\"mul\", lambda: x * data),\n", + " (\"constraint\", lambda: x >= data),\n", + "]:\n", + " try:\n", + " op_fn()\n", + " except ValueError:\n", + " print(f\"{op_name}: ValueError raised (NaN in operand)\")" + ] + }, + { + "cell_type": "markdown", + "id": "fillna-header", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Handling NaN with `.fillna()`\n", + "\n", + "When your data contains NaN, fill it explicitly before combining with expressions. 
The fill value depends on what the NaN means in your context:\n", + "\n", + "| Operation | Fill value | Meaning |\n", + "|-----------|-----------|--------|\n", + "| `expr + data.fillna(0)` | 0 | NaN = \"no offset\" |\n", + "| `expr * data.fillna(0)` | 0 | NaN = \"exclude this term\" |\n", + "| `expr * data.fillna(1)` | 1 | NaN = \"no scaling\" |\n", + "| `expr / data.fillna(1)` | 1 | NaN = \"no scaling\" |\n", + "\n", + "The choice is yours — and that's the point. Under legacy, the library chose for you (0 for add/mul, 1 for div). Under v1, you make the decision explicit." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fillna-demo", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:52:16.968586Z", + "start_time": "2026-03-11T14:52:16.956299Z" + }, + "execution": { + "iopub.execute_input": "2026-03-12T07:17:14.261186Z", + "iopub.status.busy": "2026-03-12T07:17:14.261122Z", + "iopub.status.idle": "2026-03-12T07:17:14.270213Z", + "shell.execute_reply": "2026-03-12T07:17:14.269997Z" + } + }, + "outputs": [], + "source": [ + "# Fill NaN before operating — you choose the fill value\n", + "print(\"add fillna(0):\", (x + data.fillna(0)).const.values)\n", + "print(\"mul fillna(0):\", (x * data.fillna(0)).coeffs.squeeze().values)\n", + "print(\"mul fillna(1):\", (x * data.fillna(1)).coeffs.squeeze().values)" + ] + }, + { + "cell_type": "markdown", + "id": "masking-header", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Masking constraints\n", + "\n", + "A common pattern: your data has NaN at positions where no constraint should exist. 
For example, availability data that's only defined for certain hours, or cost data with missing entries.\n", + "\n", + "### Approach 1: `.sel()` (preferred)\n", + "\n", + "Select only valid positions — the constraint has fewer coordinates:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "masking-sel", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:52:16.983888Z", + "start_time": "2026-03-11T14:52:16.974378Z" + }, + "execution": { + "iopub.execute_input": "2026-03-12T07:17:14.271290Z", + "iopub.status.busy": "2026-03-12T07:17:14.271219Z", + "iopub.status.idle": "2026-03-12T07:17:14.279948Z", + "shell.execute_reply": "2026-03-12T07:17:14.279785Z" + } + }, + "outputs": [], + "source": [ + "# Availability data with NaN = \"no limit at this hour\"\n", + "availability = xr.DataArray(\n", + " [100.0, 80.0, np.nan, np.nan, 60.0], dims=[\"time\"], coords={\"time\": time}\n", + ")\n", + "\n", + "# Select only where data is valid — constraint has fewer coordinates\n", + "valid = availability.notnull()\n", + "m.add_constraints(x.sel(time=valid) <= availability.sel(time=valid), name=\"avail\")" + ] + }, + { + "cell_type": "markdown", + "id": "masking-mask-header", + "metadata": {}, + "source": [ + "No fillna, no mask parameter — the constraint simply doesn't exist at the NaN positions.\n", + "\n", + "### Approach 2: `mask=` parameter\n", + "\n", + "When `.sel()` is inconvenient (e.g., multi-dimensional data where NaN positions vary across dimensions), use `mask=`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "masking-mask-demo", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:52:16.998421Z", + "start_time": "2026-03-11T14:52:16.990226Z" + }, + "execution": { + "iopub.execute_input": "2026-03-12T07:17:14.280842Z", + "iopub.status.busy": "2026-03-12T07:17:14.280784Z", + "iopub.status.idle": "2026-03-12T07:17:14.286824Z", + "shell.execute_reply": "2026-03-12T07:17:14.286655Z" + } + }, + "outputs": 
[], + "source": [ + "# Same result using mask= instead of .sel()\n", + "mask = availability.notnull()\n", + "m.add_constraints(x <= availability.fillna(0), name=\"avail_masked\", mask=mask)" + ] + }, + { + "cell_type": "markdown", + "id": "masking-vars-header", + "metadata": {}, + "source": [ + "The same approaches work for variables with NaN bounds:\n", + "\n", + "```python\n", + "# With .sel()\n", + "valid = upper_bounds.notnull()\n", + "m.add_variables(upper=upper_bounds.sel(i=valid), coords=[valid_coords], name=\"y\")\n", + "\n", + "# Or with mask=\n", + "mask = upper_bounds.notnull()\n", + "m.add_variables(upper=upper_bounds.fillna(0), mask=mask, name=\"y\")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "coeff-header", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Masking with NaN in coefficients\n", + "\n", + "When NaN appears in coefficient data (e.g., efficiency factors where some combinations don't apply), the same two approaches work:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "coeff-demo", + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-11T14:52:17.017774Z", + "start_time": "2026-03-11T14:52:17.003374Z" + }, + "execution": { + "iopub.execute_input": "2026-03-12T07:17:14.287819Z", + "iopub.status.busy": "2026-03-12T07:17:14.287760Z", + "iopub.status.idle": "2026-03-12T07:17:14.300622Z", + "shell.execute_reply": "2026-03-12T07:17:14.300443Z" + } + }, + "outputs": [], + "source": [ + "# Efficiency data: solar has no efficiency at night (NaN)\n", + "techs = pd.Index([\"solar\", \"wind\"], name=\"tech\")\n", + "hours = pd.RangeIndex(4, name=\"hour\")\n", + "gen = m.add_variables(lower=0, coords=[hours, techs], name=\"gen\")\n", + "\n", + "efficiency = xr.DataArray(\n", + " [[np.nan, 0.35], [0.8, 0.35], [0.9, 0.35], [np.nan, 0.35]],\n", + " dims=[\"hour\", \"tech\"],\n", + " coords={\"hour\": hours, \"tech\": techs},\n", + ")\n", + "\n", + "# Approach 1: .sel() — select only valid hours per 
tech\n", + "valid_hours = efficiency.sel(tech=\"solar\").notnull()\n", + "solar_gen = gen.sel(tech=\"solar\", hour=valid_hours)\n", + "solar_eff = efficiency.sel(tech=\"solar\", hour=valid_hours)\n", + "print(\"sel approach — solar hours:\", solar_gen.coords[\"hour\"].values)\n", + "\n", + "# Approach 2: mask= — keep all coordinates, mask invalid ones\n", + "rhs = xr.DataArray([50.0] * 4, dims=[\"hour\"], coords={\"hour\": hours})\n", + "coeff_mask = efficiency.notnull().all(\"tech\")\n", + "expr = gen * efficiency.fillna(0)\n", + "m.add_constraints(expr >= rhs, name=\"min_output\", mask=coeff_mask)\n", + "print(\"mask approach — constraint mask:\", coeff_mask.values)" + ] + }, + { + "cell_type": "markdown", + "id": "legacy-header", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Legacy NaN behavior (for comparison)\n", + "\n", + "Under legacy, NaN was handled implicitly:\n", + "- **In arithmetic**: silently replaced with neutral elements (0 for add/sub/mul, 1 for div)\n", + "- **In constraint RHS**: NaN meant \"no constraint here\" — auto-masked internally\n", + "- **With `auto_mask=True`**: NaN in variable bounds meant \"no variable here\"\n", + "\n", + "This was convenient but could mask data quality issues. A NaN from a data pipeline bug would silently become 0, producing a valid but wrong model.\n", + "\n", + "### Migration\n", + "\n", + "| Legacy code (silent) | v1 equivalent (explicit) |\n", + "|---|---|\n", + "| `x + data_with_nans` | `x + data_with_nans.fillna(0)` |\n", + "| `x * data_with_nans` | `x * data_with_nans.fillna(0)` |\n", + "| `x / data_with_nans` | `x / data_with_nans.fillna(1)` |\n", + "| `m.add_constraints(expr >= nan_rhs)` | `m.add_constraints(expr.sel(...) 
>= rhs.sel(...))` |\n", + "| `Model(auto_mask=True)` | Explicit `mask=` or `.sel()` |" + ] + }, + { + "cell_type": "markdown", + "id": "summary", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Summary\n", + "\n", + "| Aspect | v1 | Legacy |\n", + "|---|---|---|\n", + "| **NaN means** | Absent term (not absent coordinate) | Numeric placeholder (filled silently) |\n", + "| **NaN sources** | `mask=`, structural ops only | Anywhere (user data, bounds, RHS) |\n", + "| **NaN in operands** | `ValueError` | Filled with neutral element (0 or 1) |\n", + "| **NaN in constraint RHS** | `ValueError` | Auto-masked |\n", + "| **Combining expressions** | Absent terms ignored, valid terms kept | NaN filled before combining |\n", + "| **Coordinate absent when** | All terms absent AND const is NaN | Never (NaN always filled) |\n", + "| **Masking** | Explicit: `.sel()` or `mask=` | Implicit via NaN / `auto_mask` |\n", + "| **Storage** | Float fields + `-1` sentinels | Same, but NaN has dual role |\n", + "| **Fill value choice** | User decides | Library decides |" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/linopy/__init__.py b/linopy/__init__.py index b1dc33b9..a372c087 100644 --- a/linopy/__init__.py +++ b/linopy/__init__.py @@ -13,7 +13,7 @@ # we need to extend their __mul__ functions with a quick special case import linopy.monkey_patch_xarray # noqa: F401 from linopy.common import align -from linopy.config import options +from linopy.config import LinopyDeprecationWarning, options from linopy.constants import EQUAL, GREATER_EQUAL, LESS_EQUAL from linopy.constraints import 
Constraint, Constraints from linopy.expressions import LinearExpression, QuadraticExpression, merge @@ -34,6 +34,7 @@ "EQUAL", "GREATER_EQUAL", "LESS_EQUAL", + "LinopyDeprecationWarning", "LinearExpression", "Model", "Objective", diff --git a/linopy/common.py b/linopy/common.py index 09f67355..a1022189 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -205,6 +205,9 @@ def numpy_to_dataarray( if isinstance(coords, list): coords = dict(zip(dims, coords[: arr.ndim])) elif is_dict_like(coords): + # Filter coords to matching dims — this is expected when a + # lower-dimensional constant is broadcast against an expression + # whose full coords are passed through as_dataarray. coords = {k: v for k, v in coords.items() if k in dims} return DataArray(arr, coords=coords, dims=dims, **kwargs) @@ -1205,7 +1208,7 @@ def check_common_keys_values(list_of_dicts: list[dict[str, Any]]) -> bool: def align( *objects: LinearExpression | QuadraticExpression | Variable | T_Alignable, - join: JoinOptions = "inner", + join: JoinOptions | None = None, copy: bool = True, indexes: Any = None, exclude: str | Iterable[Hashable] = frozenset(), @@ -1265,9 +1268,26 @@ def align( """ + from linopy.config import options from linopy.expressions import LinearExpression, QuadraticExpression from linopy.variables import Variable + if join is None: + join = options["arithmetic_convention"] + + if join == "legacy": + from linopy.config import LEGACY_DEPRECATION_MESSAGE, LinopyDeprecationWarning + + warn( + LEGACY_DEPRECATION_MESSAGE, + LinopyDeprecationWarning, + stacklevel=2, + ) + join = "inner" + + elif join == "v1": + join = "exact" + finisher: list[partial[Any] | Callable[[Any], Any]] = [] das: list[Any] = [] for obj in objects: diff --git a/linopy/config.py b/linopy/config.py index c098709d..9f04ce17 100644 --- a/linopy/config.py +++ b/linopy/config.py @@ -9,28 +9,46 @@ from typing import Any +VALID_ARITHMETIC_JOINS = {"legacy", "v1"} + +LEGACY_DEPRECATION_MESSAGE = ( + "The 'legacy' arithmetic 
convention is deprecated and will be removed in " + "linopy v1. Set linopy.options['arithmetic_convention'] = 'v1' to opt in " + "to the new behavior, or filter this warning with:\n" + " import warnings; warnings.filterwarnings('ignore', category=LinopyDeprecationWarning)" +) + + +class LinopyDeprecationWarning(FutureWarning): + """Warning for deprecated linopy features scheduled for removal.""" + class OptionSettings: - def __init__(self, **kwargs: int) -> None: + def __init__(self, **kwargs: Any) -> None: self._defaults = kwargs self._current_values = kwargs.copy() - def __call__(self, **kwargs: int) -> None: + def __call__(self, **kwargs: Any) -> None: self.set_value(**kwargs) - def __getitem__(self, key: str) -> int: + def __getitem__(self, key: str) -> Any: return self.get_value(key) - def __setitem__(self, key: str, value: int) -> None: + def __setitem__(self, key: str, value: Any) -> None: return self.set_value(**{key: value}) - def set_value(self, **kwargs: int) -> None: + def set_value(self, **kwargs: Any) -> None: for k, v in kwargs.items(): if k not in self._defaults: raise KeyError(f"{k} is not a valid setting.") + if k == "arithmetic_convention" and v not in VALID_ARITHMETIC_JOINS: + raise ValueError( + f"Invalid arithmetic_convention: {v!r}. " + f"Must be one of {VALID_ARITHMETIC_JOINS}." 
+ ) self._current_values[k] = v - def get_value(self, name: str) -> int: + def get_value(self, name: str) -> Any: if name in self._defaults: return self._current_values[name] else: @@ -57,4 +75,8 @@ def __repr__(self) -> str: return f"OptionSettings:\n {settings}" -options = OptionSettings(display_max_rows=14, display_max_terms=6) +options = OptionSettings( + display_max_rows=14, + display_max_terms=6, + arithmetic_convention="legacy", +) diff --git a/linopy/expressions.py b/linopy/expressions.py index d2ae9022..d09916e7 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -30,8 +30,15 @@ from xarray import Coordinates, DataArray, Dataset, IndexVariable from xarray.core.coordinates import DataArrayCoordinates, DatasetCoordinates from xarray.core.indexes import Indexes +from xarray.core.types import JoinOptions from xarray.core.utils import Frozen +try: + from xarray.structure.alignment import AlignmentError +except ImportError: + # Fallback for older xarray versions where this isn't a separate class + AlignmentError = ValueError # type: ignore[assignment, misc] + try: # resolve breaking change in xarray 2025.03.0 import xarray.computation.rolling @@ -48,7 +55,6 @@ LocIndexer, as_dataarray, assign_multiindex_safe, - check_common_keys_values, check_has_nulls, check_has_nulls_polars, fill_missing_coords, @@ -67,7 +73,7 @@ to_dataframe, to_polars, ) -from linopy.config import options +from linopy.config import LEGACY_DEPRECATION_MESSAGE, LinopyDeprecationWarning, options from linopy.constants import ( CV_DIM, EQUAL, @@ -90,6 +96,8 @@ ) if TYPE_CHECKING: + from typing_extensions import Self + from linopy.constraints import AnonymousScalarConstraint, Constraint from linopy.model import Model from linopy.piecewise import PiecewiseConstraintDescriptor, PiecewiseExpression @@ -548,7 +556,7 @@ def _align_constant( self: GenericExpression, other: DataArray, fill_value: float = 0, - join: str | None = None, + join: JoinOptions | None = None, ) -> tuple[DataArray, 
DataArray, bool]: """ Align a constant DataArray with self.const. @@ -560,7 +568,7 @@ def _align_constant( fill_value : float, default: 0 Fill value for missing coordinates. join : str, optional - Alignment method. If None, uses size-aware default behavior. + Alignment method. If None, uses ``options["arithmetic_convention"]``. Returns ------- @@ -572,6 +580,16 @@ def _align_constant( Whether the expression's data needs reindexing. """ if join is None: + join = options["arithmetic_convention"] + + if join == "legacy": + # stacklevel=4: user code -> __add__/__mul__ -> _add_constant/_apply_constant_op -> _align_constant + warn( + LEGACY_DEPRECATION_MESSAGE, + LinopyDeprecationWarning, + stacklevel=4, + ) + # Old behavior: override when same sizes, left join otherwise if other.sizes == self.const.sizes: return self.const, other.assign_coords(coords=self.coords), False return ( @@ -579,33 +597,68 @@ def _align_constant( other.reindex_like(self.const, fill_value=fill_value), False, ) - elif join == "override": + + elif join == "v1": + join = "exact" + + if join == "override": return self.const, other.assign_coords(coords=self.coords), False - else: - self_const, aligned = xr.align( + elif join == "left": + return ( self.const, - other, - join=join, - fill_value=fill_value, # type: ignore[call-overload] + other.reindex_like(self.const, fill_value=fill_value), + False, ) + else: + try: + self_const, aligned = xr.align( + self.const, other, join=join, fill_value=fill_value + ) + except ValueError as e: + if "exact" in str(e): + raise ValueError( + f"{e}\n" + "Use .add()/.sub()/.mul()/.div() with an explicit join= parameter:\n" + ' .add(other, join="inner") # intersection of coordinates\n' + ' .add(other, join="outer") # union of coordinates (with fill)\n' + ' .add(other, join="left") # keep left operand\'s coordinates\n' + ' .add(other, join="override") # positional alignment' + ) from None + raise return self_const, aligned, True def _add_constant( - self: 
GenericExpression, other: ConstantLike, join: str | None = None + self: GenericExpression, other: ConstantLike, join: JoinOptions | None = None ) -> GenericExpression: - # NaN values in self.const or other are filled with 0 (additive identity) - # so that missing data does not silently propagate through arithmetic. + is_legacy = ( + join is None and options["arithmetic_convention"] == "legacy" + ) or join == "legacy" if np.isscalar(other) and join is None: - return self.assign(const=self.const.fillna(0) + other) + if not is_legacy and np.isnan(other): + raise ValueError( + "Constant contains NaN values. Use .fillna() to handle " + "missing values before arithmetic operations." + ) + const = self.const.fillna(0) + other + return self.assign(const=const) da = as_dataarray(other, coords=self.coords, dims=self.coord_dims) self_const, da, needs_data_reindex = self._align_constant( da, fill_value=0, join=join ) - da = da.fillna(0) + # Always fill self_const with 0 (additive identity) to stay + # consistent with merge() and preserve associativity. self_const = self_const.fillna(0) + if is_legacy: + da = da.fillna(0) + elif da.isnull().any(): + raise ValueError( + "Constant contains NaN values. Use .fillna() to handle " + "missing values before arithmetic operations." + ) if needs_data_reindex: + fv = {**self._fill_value, "const": 0} return self.__class__( - self.data.reindex_like(self_const, fill_value=self._fill_value).assign( + self.data.reindex_like(self_const, fill_value=fv).assign( const=self_const + da ), self.model, @@ -617,40 +670,56 @@ def _apply_constant_op( other: ConstantLike, op: Callable[[DataArray, DataArray], DataArray], fill_value: float, - join: str | None = None, + join: JoinOptions | None = None, ) -> GenericExpression: - """ - Apply a constant operation (mul, div, etc.) to this expression with a scalar or array. 
- - NaN values are filled with neutral elements before the operation: - - factor (other) is filled with fill_value (0 for mul, 1 for div) - - coeffs and const are filled with 0 (additive identity) - """ + is_legacy = ( + join is None and options["arithmetic_convention"] == "legacy" + ) or join == "legacy" + # Fast path for scalars: no dimensions to align + if np.isscalar(other): + if not is_legacy and np.isnan(other): + raise ValueError( + "Factor contains NaN values. Use .fillna() to handle " + "missing values before arithmetic operations." + ) + coeffs = self.coeffs.fillna(0) if is_legacy else self.coeffs + const = self.const.fillna(0) if is_legacy else self.const + scalar = DataArray(other) + return self.assign(coeffs=op(coeffs, scalar), const=op(const, scalar)) factor = as_dataarray(other, coords=self.coords, dims=self.coord_dims) self_const, factor, needs_data_reindex = self._align_constant( factor, fill_value=fill_value, join=join ) - factor = factor.fillna(fill_value) - self_const = self_const.fillna(0) + if is_legacy: + factor = factor.fillna(fill_value) + self_const = self_const.fillna(0) + elif factor.isnull().any(): + raise ValueError( + "Factor contains NaN values. Use .fillna() to handle " + "missing values before arithmetic operations." 
+ ) if needs_data_reindex: - data = self.data.reindex_like(self_const, fill_value=self._fill_value) - coeffs = data.coeffs.fillna(0) + fv = {**self._fill_value, "const": 0} + data = self.data.reindex_like(self_const, fill_value=fv) + coeffs = data.coeffs.fillna(0) if is_legacy else data.coeffs return self.__class__( assign_multiindex_safe( - data, coeffs=op(coeffs, factor), const=op(self_const, factor) + data, + coeffs=op(coeffs, factor), + const=op(self_const, factor), ), self.model, ) - coeffs = self.coeffs.fillna(0) + coeffs = self.coeffs.fillna(0) if is_legacy else self.coeffs return self.assign(coeffs=op(coeffs, factor), const=op(self_const, factor)) def _multiply_by_constant( - self: GenericExpression, other: ConstantLike, join: str | None = None + self: GenericExpression, other: ConstantLike, join: JoinOptions | None = None ) -> GenericExpression: return self._apply_constant_op(other, operator.mul, fill_value=0, join=join) def _divide_by_constant( - self: GenericExpression, other: ConstantLike, join: str | None = None + self: GenericExpression, other: ConstantLike, join: JoinOptions | None = None ) -> GenericExpression: return self._apply_constant_op(other, operator.truediv, fill_value=1, join=join) @@ -659,7 +728,7 @@ def __div__(self: GenericExpression, other: SideLike) -> GenericExpression: if isinstance(other, SUPPORTED_EXPRESSION_TYPES): raise TypeError( "unsupported operand type(s) for /: " - f"{type(self)} and {type(other)}" + f"{type(self)} and {type(other)}. " "Non-linear expressions are not yet supported." ) return self._divide_by_constant(other) @@ -718,7 +787,8 @@ def __lt__(self, other: Any) -> NotImplementedType: def add( self: GenericExpression, other: SideLike, - join: str | None = None, + join: JoinOptions | None = None, + fill_value: float | None = None, ) -> GenericExpression | QuadraticExpression: """ Add an expression to others. @@ -731,22 +801,27 @@ def add( How to align coordinates. 
One of "outer", "inner", "left", "right", "exact", "override". When None (default), uses the current default behavior. + fill_value : float, optional + Fill NaN in the expression's constant before adding. Useful + for reviving absent slots with a defined value. """ + expr = self.fillna(fill_value) if fill_value is not None else self if join is None: - return self.__add__(other) + return expr.__add__(other) if isinstance(other, SUPPORTED_CONSTANT_TYPES): - return self._add_constant(other, join=join) + return expr._add_constant(other, join=join) other = as_expression(other, model=self.model, dims=self.coord_dims) if isinstance(other, LinearExpression) and isinstance( - self, QuadraticExpression + expr, QuadraticExpression ): other = other.to_quadexpr() - return merge([self, other], cls=self.__class__, join=join) # type: ignore[list-item] + return merge([expr, other], cls=self.__class__, join=join) # type: ignore[list-item] def sub( self: GenericExpression, other: SideLike, - join: str | None = None, + join: JoinOptions | None = None, + fill_value: float | None = None, ) -> GenericExpression | QuadraticExpression: """ Subtract others from expression. @@ -759,13 +834,16 @@ def sub( How to align coordinates. One of "outer", "inner", "left", "right", "exact", "override". When None (default), uses the current default behavior. + fill_value : float, optional + Fill NaN in the expression's constant before subtracting. """ - return self.add(-other, join=join) + return self.add(-other, join=join, fill_value=fill_value) def mul( self: GenericExpression, other: SideLike, - join: str | None = None, + join: JoinOptions | None = None, + fill_value: float | None = None, ) -> GenericExpression | QuadraticExpression: """ Multiply the expr by a factor. @@ -778,19 +856,23 @@ def mul( How to align coordinates. One of "outer", "inner", "left", "right", "exact", "override". When None (default), uses the current default behavior. 
+ fill_value : float, optional + Fill NaN in the expression's constant before multiplying. """ + expr = self.fillna(fill_value) if fill_value is not None else self if join is None: - return self.__mul__(other) + return expr.__mul__(other) if isinstance(other, SUPPORTED_EXPRESSION_TYPES): raise TypeError( "join parameter is not supported for expression-expression multiplication" ) - return self._multiply_by_constant(other, join=join) + return expr._multiply_by_constant(other, join=join) def div( self: GenericExpression, other: VariableLike | ConstantLike, - join: str | None = None, + join: JoinOptions | None = None, + fill_value: float | None = None, ) -> GenericExpression | QuadraticExpression: """ Divide the expr by a factor. @@ -803,21 +885,24 @@ def div( How to align coordinates. One of "outer", "inner", "left", "right", "exact", "override". When None (default), uses the current default behavior. + fill_value : float, optional + Fill NaN in the expression's constant before dividing. """ + expr = self.fillna(fill_value) if fill_value is not None else self if join is None: - return self.__div__(other) + return expr.__div__(other) if isinstance(other, SUPPORTED_EXPRESSION_TYPES): raise TypeError( "unsupported operand type(s) for /: " f"{type(self)} and {type(other)}. " "Non-linear expressions are not yet supported." ) - return self._divide_by_constant(other, join=join) + return expr._divide_by_constant(other, join=join) def le( self: GenericExpression, rhs: SideLike, - join: str | None = None, + join: JoinOptions | None = None, ) -> Constraint: """ Less than or equal constraint. @@ -836,7 +921,7 @@ def le( def ge( self: GenericExpression, rhs: SideLike, - join: str | None = None, + join: JoinOptions | None = None, ) -> Constraint: """ Greater than or equal constraint. @@ -855,7 +940,7 @@ def ge( def eq( self: GenericExpression, rhs: SideLike, - join: str | None = None, + join: JoinOptions | None = None, ) -> Constraint: """ Equality constraint. 
@@ -1111,7 +1196,7 @@ def cumsum( return self.rolling(dim=dim_dict).sum(keep_attrs=keep_attrs, skipna=skipna) def to_constraint( - self, sign: SignLike, rhs: SideLike, join: str | None = None + self, sign: SignLike, rhs: SideLike, join: JoinOptions | None = None ) -> Constraint: """ Convert a linear expression to a constraint. @@ -1141,34 +1226,85 @@ def to_constraint( f"Both sides of the constraint are constant. At least one side must contain variables. {self} {rhs}" ) - if isinstance(rhs, SUPPORTED_CONSTANT_TYPES): - rhs = as_dataarray(rhs, coords=self.coords, dims=self.coord_dims) + effective_join = join if join is not None else options["arithmetic_convention"] - extra_dims = set(rhs.dims) - set(self.coord_dims) - if extra_dims: - logger.warning( - f"Constant RHS contains dimensions {extra_dims} not present " - f"in the expression, which might lead to inefficiencies. " - f"Consider collapsing the dimensions by taking min/max." + if effective_join == "legacy": + warn( + LEGACY_DEPRECATION_MESSAGE, + LinopyDeprecationWarning, + stacklevel=3, + ) + # Old behavior: convert to DataArray, warn about extra dims, + # reindex_like (left join), then sub + if isinstance(rhs, SUPPORTED_CONSTANT_TYPES): + rhs = as_dataarray(rhs, coords=self.coords, dims=self.coord_dims) + extra_dims = set(rhs.dims) - set(self.coord_dims) + if extra_dims: + logger.warning( + f"Constant RHS contains dimensions {extra_dims} not present " + f"in the expression, which might lead to inefficiencies. " + f"Consider collapsing the dimensions by taking min/max." 
+ ) + rhs = rhs.reindex_like(self.const, fill_value=np.nan) + # Alignment already done — compute constraint directly + constraint_rhs = rhs - self.const + data = assign_multiindex_safe( + self.data[["coeffs", "vars"]], sign=sign, rhs=constraint_rhs ) - rhs = rhs.reindex_like(self.const, fill_value=np.nan) + return constraints.Constraint(data, model=self.model) + # Non-constant rhs (Variable/Expression) — fall through to sub path + + if effective_join == "v1": + effective_join = "exact" + + if isinstance(rhs, SUPPORTED_CONSTANT_TYPES) and not isinstance(rhs, DataArray): + rhs = as_dataarray(rhs, coords=self.coords, dims=self.coord_dims) - # Remember where RHS is NaN (meaning "no constraint") before the - # subtraction, which may fill NaN with 0 as part of normal - # expression arithmetic. if isinstance(rhs, DataArray): - rhs_nan_mask = rhs.isnull() - else: - rhs_nan_mask = None + is_legacy = ( + join is None and options["arithmetic_convention"] == "legacy" + ) or join == "legacy" + if not is_legacy and rhs.isnull().any(): + raise ValueError( + "Constraint RHS contains NaN values. Use .fillna() and " + "mask= to handle missing values explicitly." 
+ ) + if effective_join == "override": + aligned_rhs = rhs.assign_coords(coords=self.const.coords) + expr_const = self.const + expr_data = self.data + elif effective_join == "left": + aligned_rhs = rhs.reindex_like(self.const, fill_value=np.nan) + expr_const = self.const + expr_data = self.data + else: + try: + expr_const_aligned, aligned_rhs = xr.align( + self.const, rhs, join=effective_join, fill_value=np.nan + ) + except ValueError as e: + if "exact" in str(e): + raise ValueError( + f"{e}\n" + "Use .le()/.ge()/.eq() with an explicit join= parameter:\n" + ' .le(rhs, join="inner") # intersection of coordinates\n' + ' .le(rhs, join="left") # keep expression coordinates (NaN fill)\n' + ' .le(rhs, join="override") # positional alignment' + ) from None + raise + expr_const = expr_const_aligned.fillna(0) + expr_data = self.data.reindex_like( + expr_const_aligned, fill_value=self._fill_value + ) + constraint_rhs = aligned_rhs - expr_const + data = assign_multiindex_safe( + expr_data[["coeffs", "vars"]], sign=sign, rhs=constraint_rhs + ) + return constraints.Constraint(data, model=self.model) all_to_lhs = self.sub(rhs, join=join).data computed_rhs = -all_to_lhs.const - # Restore NaN at positions where the original constant RHS had no - # value so that downstream code still treats them as unconstrained. - if rhs_nan_mask is not None and rhs_nan_mask.any(): - computed_rhs = xr.where(rhs_nan_mask, np.nan, computed_rhs) - data = assign_multiindex_safe( all_to_lhs[["coeffs", "vars"]], sign=sign, rhs=computed_rhs ) @@ -1517,9 +1653,47 @@ def _sum( set_index = exprwrap(Dataset.set_index) - reindex = exprwrap(Dataset.reindex, fill_value=_fill_value) + def reindex( + self, + indexers: Mapping[Any, Any] | None = None, + fill_value: float = np.nan, + **indexers_kwargs: Any, + ) -> Self: + """ + Reindex the expression. + + ``fill_value`` sets the constant for missing coordinates (default NaN). + Variable labels and coefficients always use sentinel values + (vars=-1, coeffs=NaN). 
+ """ + fv = {**self._fill_value, "const": fill_value} + return self.__class__( + self.data.reindex(indexers, fill_value=fv, **indexers_kwargs), self.model + ) + + def reindex_like( + self, + other: Any, + fill_value: float = np.nan, + **kwargs: Any, + ) -> Self: + """ + Reindex like another object. - reindex_like = exprwrap(Dataset.reindex_like, fill_value=_fill_value) + ``fill_value`` sets the constant for missing coordinates (default NaN). + Variable labels and coefficients always use sentinel values. + """ + fv = {**self._fill_value, "const": fill_value} + if isinstance(other, DataArray): + ref = other.to_dataset(name="__tmp__") + elif isinstance(other, Dataset): + ref = other + else: + ref = other.data + return self.__class__( + self.data.reindex_like(ref, fill_value=fv, **kwargs), + self.model, + ) rename = exprwrap(Dataset.rename) @@ -2208,7 +2382,7 @@ def solution(self) -> DataArray: return sol.rename("solution") def to_constraint( - self, sign: SignLike, rhs: SideLike, join: str | None = None + self, sign: SignLike, rhs: SideLike, join: JoinOptions | None = None ) -> NotImplementedType: raise NotImplementedError( "Quadratic expressions cannot be used in constraints." @@ -2341,17 +2515,29 @@ def merge( ], dim: str = TERM_DIM, cls: type[GenericExpression] = None, # type: ignore - join: str | None = None, + join: JoinOptions | None = None, **kwargs: Any, ) -> GenericExpression: """ - Merge multiple expression together. + Merge multiple expressions together. + + Concatenates expressions along a given dimension (default: ``_term``). + Faster than summing expressions individually. + + Join behavior by convention (when ``join=None``): + + - **v1**: Enforces exact match on shared user-dimension coordinates. + Helper dims (``_term``, ``_factor``) and the concat dim are excluded + from this check. Raises ``ValueError`` on mismatch. The actual + ``xr.concat`` uses ``join="outer"`` since helper dims legitimately + differ between expressions. 
+ - **legacy**: Uses ``join="override"`` (positional alignment) when all + shared user dims have matching sizes, ``join="outer"`` otherwise. + - **explicit** (e.g. ``join="inner"``): Passed through to ``xr.concat``. + + Internal callers that bypass the convention: - This function is a bit faster than summing over multiple linear expressions. - In case a list of LinearExpression with exactly the same shape is passed - and the dimension to concatenate on is TERM_DIM, the concatenation uses - the coordinates of the first object as a basis which overrides the - coordinates of the consecutive objects. + - ``.add(join=X)``: passes explicit join through. Parameters ---------- @@ -2360,21 +2546,20 @@ def merge( dim : str Dimension along which the expressions should be concatenated. cls : type - Explicitly set the type of the resulting expression (So that the type checker will know the return type) + Explicitly set the type of the resulting expression (So that the + type checker will know the return type) join : str, optional How to align coordinates. One of "outer", "inner", "left", "right", - "exact", "override". When None (default), auto-detects based on - expression shapes. + "exact", "override". When None (default), uses the current + arithmetic convention. **kwargs - Additional keyword arguments passed to xarray.concat. Defaults to - {coords: "minimal", compat: "override"} or, in the special case described - above, to {coords: "minimal", compat: "override", "join": "override"}. + Additional keyword arguments passed to xarray.concat. Returns ------- res : linopy.LinearExpression or linopy.QuadraticExpression """ - if not isinstance(exprs, list) and len(add_exprs): + if not isinstance(exprs, list) and len(add_exprs) > 0: warn( "Passing a tuple to the merge function is deprecated. 
Please pass a list of objects to be merged", DeprecationWarning, @@ -2399,16 +2584,6 @@ def merge( model = exprs[0].model - if join is not None: - override = join == "override" - elif cls in linopy_types and dim in HELPER_DIMS: - coord_dims = [ - {k: v for k, v in e.sizes.items() if k not in HELPER_DIMS} for e in exprs - ] - override = check_common_keys_values(coord_dims) # type: ignore - else: - override = False - data = [e.data if isinstance(e, linopy_types) else e for e in exprs] data = [fill_missing_coords(ds, fill_helper_dims=True) for ds in data] @@ -2422,17 +2597,69 @@ def merge( elif cls == variables.Variable: kwargs["fill_value"] = variables.FILL_VALUE - if join is not None: - kwargs["join"] = join - elif override: - kwargs["join"] = "override" + effective_join = join if join is not None else options["arithmetic_convention"] + + if effective_join == "legacy": + warn( + LEGACY_DEPRECATION_MESSAGE, + LinopyDeprecationWarning, + stacklevel=2, + ) + # Reproduce old behavior: override when all shared dims have + # matching sizes, outer otherwise. + if cls in linopy_types and dim in HELPER_DIMS: + coord_dims = [ + {k: v for k, v in e.sizes.items() if k not in HELPER_DIMS} + for e in exprs + ] + common_keys = set.intersection(*(set(d.keys()) for d in coord_dims)) + override = all( + len({d[k] for d in coord_dims if k in d}) == 1 for k in common_keys + ) + else: + override = False + + kwargs["join"] = "override" if override else "outer" + elif effective_join == "v1": + # Enforce exact alignment on user dims only. Helper dims + # (_term, _factor) legitimately differ between expressions, + # so we can't pass join="exact" to xr.concat directly. + # Instead: pre-validate user dims, then concat with outer. + # Check only dimension-coordinates (not scalar coords left + # from .sel()), excluding helper dims and the concat dim. 
+ skip_dims = set(HELPER_DIMS) | {dim} + user_coords = [ + {k: d.coords[k] for k in d.dims if k not in skip_dims} for d in data + ] + # Only check dims shared by all datasets (broadcasting is OK) + shared_dims = set.intersection(*(set(c.keys()) for c in user_coords)) + for d_name in shared_dims: + ref = user_coords[0][d_name] + for i, uc in enumerate(user_coords[1:], 1): + if not ref.equals(uc[d_name]): + raise ValueError( + f"Coordinate mismatch on dimension '{d_name}'.\n" + "Use .add()/.sub() with an explicit join= parameter:\n" + ' .add(other, join="inner") # intersection of coordinates\n' + ' .add(other, join="outer") # union of coordinates (with fill)\n' + ' .add(other, join="left") # keep left operand\'s coordinates' + ) + kwargs["join"] = "outer" else: - kwargs.setdefault("join", "outer") + # Explicit join passed through (e.g., from .add(join="inner")) + kwargs["join"] = effective_join if dim == TERM_DIM: ds = xr.concat([d[["coeffs", "vars"]] for d in data], dim, **kwargs) + # Concat without fill to detect where all constants were NaN + raw_consts = xr.concat([d["const"] for d in data], dim, **kwargs) + all_const_nan = raw_consts.isnull().all(TERM_DIM) + # Sum with fill_value=0 so valid NaN + valid 5 = 5 (not NaN) subkwargs = {**kwargs, "fill_value": 0} const = xr.concat([d["const"] for d in data], dim, **subkwargs).sum(TERM_DIM) + # Restore NaN where all input constants were NaN (all terms absent) + if all_const_nan.any(): + const = const.where(~all_const_nan) ds = assign_multiindex_safe(ds, const=const) elif dim == FACTOR_DIM: ds = xr.concat([d[["vars"]] for d in data], dim, **kwargs) diff --git a/linopy/piecewise.py b/linopy/piecewise.py index 78f7be65..489a8bdf 100644 --- a/linopy/piecewise.py +++ b/linopy/piecewise.py @@ -762,10 +762,10 @@ def _add_pwl_sos2_core( lambda_var.sum(dim=BREAKPOINT_DIM) == rhs, name=convex_name ) - x_weighted = (lambda_var * x_points).sum(dim=BREAKPOINT_DIM) + x_weighted = (lambda_var * 
x_points.fillna(0)).sum(dim=BREAKPOINT_DIM) model.add_constraints(x_expr == x_weighted, name=x_link_name) - y_weighted = (lambda_var * y_points).sum(dim=BREAKPOINT_DIM) + y_weighted = (lambda_var * y_points.fillna(0)).sum(dim=BREAKPOINT_DIM) model.add_constraints(target_expr == y_weighted, name=y_link_name) return convex_con @@ -851,13 +851,20 @@ def _add_pwl_incremental_core( delta_lo = delta_var.isel({LP_SEG_DIM: slice(None, -1)}, drop=True) delta_hi = delta_var.isel({LP_SEG_DIM: slice(1, None)}, drop=True) # Keep existing fill constraint as LP relaxation tightener + # Align coords for positional comparison (lo=[0..n-2], hi=[1..n-1]) + delta_hi = delta_hi.assign_coords( + {LP_SEG_DIM: delta_lo.coords[LP_SEG_DIM].values} + ) fill_con = model.add_constraints(delta_hi <= delta_lo, name=fill_name) binary_hi = binary_var.isel({LP_SEG_DIM: slice(1, None)}, drop=True) + binary_hi = binary_hi.assign_coords( + {LP_SEG_DIM: delta_lo.coords[LP_SEG_DIM].values} + ) model.add_constraints(binary_hi <= delta_lo, name=inc_order_name) - x0 = x_points.isel({BREAKPOINT_DIM: 0}) - y0 = y_points.isel({BREAKPOINT_DIM: 0}) + x0 = x_points.isel({BREAKPOINT_DIM: 0}, drop=True) + y0 = y_points.isel({BREAKPOINT_DIM: 0}, drop=True) # When active is provided, multiply base terms by active x_base: DataArray | LinearExpression = x0 diff --git a/linopy/solvers.py b/linopy/solvers.py index 474459fe..73eb9339 100644 --- a/linopy/solvers.py +++ b/linopy/solvers.py @@ -1542,6 +1542,7 @@ def solve_problem_from_file( condition = m.getStatus() termination_condition = CONDITION_MAP.get(condition, condition) + assert termination_condition is not None status = Status.from_termination_condition(termination_condition) status.legacy_status = condition diff --git a/linopy/sos_reformulation.py b/linopy/sos_reformulation.py index 8ccb7613..8e3e0330 100644 --- a/linopy/sos_reformulation.py +++ b/linopy/sos_reformulation.py @@ -182,7 +182,8 @@ def reformulate_sos2( added_constraints = [first_name] 
model.add_constraints( - x_expr.isel({sos_dim: 0}) <= M.isel({sos_dim: 0}) * z_expr.isel({sos_dim: 0}), + x_expr.isel({sos_dim: 0}, drop=True) + <= M.isel({sos_dim: 0}, drop=True) * z_expr.isel({sos_dim: 0}, drop=True), name=first_name, ) @@ -208,8 +209,9 @@ def reformulate_sos2( added_constraints.append(mid_name) model.add_constraints( - x_expr.isel({sos_dim: n - 1}) - <= M.isel({sos_dim: n - 1}) * z_expr.isel({sos_dim: n - 2}), + x_expr.isel({sos_dim: n - 1}, drop=True) + <= M.isel({sos_dim: n - 1}, drop=True) + * z_expr.isel({sos_dim: n - 2}, drop=True), name=last_name, ) added_constraints.extend([last_name, card_name]) diff --git a/linopy/variables.py b/linopy/variables.py index f99fb938..2621c045 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -27,6 +27,7 @@ from xarray import DataArray, Dataset, broadcast from xarray.core.coordinates import DatasetCoordinates from xarray.core.indexes import Indexes +from xarray.core.types import JoinOptions from xarray.core.utils import Frozen import linopy.expressions as expressions @@ -321,6 +322,23 @@ def to_linexpr( ds = Dataset({"coeffs": coefficient, "vars": self.labels}).expand_dims( TERM_DIM, -1 ) + # In v1 mode, set coeffs=NaN and const=NaN where the variable is + # absent so that absence propagates through arithmetic (consistent + # with expression path where shift/where/reindex fill with FILL_VALUE) + if options["arithmetic_convention"] == "v1": + absent = self.labels == -1 + if absent.any(): + nan_fill = DataArray( + np.where(absent, np.nan, 0.0), coords=self.labels.coords + ) + coeff_fill = DataArray( + np.where(absent, np.nan, coefficient.values), + coords=self.labels.coords, + ) + ds = ds.assign( + const=nan_fill, + coeffs=coeff_fill.expand_dims(TERM_DIM, -1), + ) return expressions.LinearExpression(ds, self.model) def __repr__(self) -> str: @@ -403,7 +421,11 @@ def __mul__(self, other: SideLike) -> ExpressionLike: if isinstance(other, Variable | ScalarVariable): return self.to_linexpr() * other - 
return self.to_linexpr(other) + # Fast path for scalars: build expression directly with coefficient + if np.isscalar(other): + return self.to_linexpr(other) + + return self.to_linexpr() * other except TypeError: return NotImplemented @@ -566,7 +588,7 @@ def __contains__(self, value: str) -> bool: return self.data.__contains__(value) def add( - self, other: SideLike, join: str | None = None + self, other: SideLike, join: JoinOptions | None = None ) -> LinearExpression | QuadraticExpression: """ Add variables to linear expressions or other variables. @@ -583,7 +605,7 @@ def add( return self.to_linexpr().add(other, join=join) def sub( - self, other: SideLike, join: str | None = None + self, other: SideLike, join: JoinOptions | None = None ) -> LinearExpression | QuadraticExpression: """ Subtract linear expressions or other variables from the variables. @@ -600,7 +622,7 @@ def sub( return self.to_linexpr().sub(other, join=join) def mul( - self, other: ConstantLike, join: str | None = None + self, other: ConstantLike, join: JoinOptions | None = None ) -> LinearExpression | QuadraticExpression: """ Multiply variables with a coefficient. @@ -617,7 +639,7 @@ def mul( return self.to_linexpr().mul(other, join=join) def div( - self, other: ConstantLike, join: str | None = None + self, other: ConstantLike, join: JoinOptions | None = None ) -> LinearExpression | QuadraticExpression: """ Divide variables with a coefficient. @@ -633,7 +655,7 @@ def div( """ return self.to_linexpr().div(other, join=join) - def le(self, rhs: SideLike, join: str | None = None) -> Constraint: + def le(self, rhs: SideLike, join: JoinOptions | None = None) -> Constraint: """ Less than or equal constraint. 
@@ -648,7 +670,7 @@ def le(self, rhs: SideLike, join: str | None = None) -> Constraint: """ return self.to_linexpr().le(rhs, join=join) - def ge(self, rhs: SideLike, join: str | None = None) -> Constraint: + def ge(self, rhs: SideLike, join: JoinOptions | None = None) -> Constraint: """ Greater than or equal constraint. @@ -663,7 +685,7 @@ def ge(self, rhs: SideLike, join: str | None = None) -> Constraint: """ return self.to_linexpr().ge(rhs, join=join) - def eq(self, rhs: SideLike, join: str | None = None) -> Constraint: + def eq(self, rhs: SideLike, join: JoinOptions | None = None) -> Constraint: """ Equality constraint. @@ -1139,19 +1161,30 @@ def where( def fillna( self, - fill_value: ScalarVariable | dict[str, str | float | int] | Variable | Dataset, - ) -> Variable: + fill_value: int + | float + | ScalarVariable + | dict[str, str | float | int] + | Variable + | Dataset, + ) -> Variable | expressions.LinearExpression: """ - Fill missing values with a variable. + Fill missing values. - This operation call ``xarray.DataArray.fillna`` but ensures preserving - the linopy.Variable type. + When ``fill_value`` is numeric, absent variable slots are replaced + with that constant and a :class:`LinearExpression` is returned + (since a constant is not a variable). When ``fill_value`` is a + Variable, the result stays a Variable. Parameters ---------- - fill_value : Variable/ScalarVariable - Variable to use for filling. + fill_value : numeric, Variable, or ScalarVariable + Value to use for filling. Numeric values produce a + LinearExpression; Variable values produce a Variable. 
""" + if isinstance(fill_value, int | float | np.integer | np.floating): + expr = self.to_linexpr() + return expr.fillna(fill_value) return self.where(~self.isnull(), fill_value) def ffill(self, dim: str, limit: None = None) -> Variable: @@ -1250,6 +1283,41 @@ def equals(self, other: Variable) -> bool: shift = varwrap(Dataset.shift, fill_value=_fill_value) + def reindex( + self, + indexers: Mapping[Any, Any] | None = None, + **indexers_kwargs: Any, + ) -> Variable: + """ + Reindex the variable, filling with sentinel values. + + Always fills with labels=-1, lower=NaN, upper=NaN to preserve + valid label references. + """ + return self.__class__( + self.data.reindex(indexers, fill_value=self._fill_value, **indexers_kwargs), + self.model, + self.name, + ) + + def reindex_like( + self, + other: Any, + **kwargs: Any, + ) -> Variable: + """Reindex like another object, filling with sentinel values.""" + if isinstance(other, DataArray): + ref = other.to_dataset(name="__tmp__") + elif isinstance(other, Dataset): + ref = other + else: + ref = other.data + return self.__class__( + self.data.reindex_like(ref, fill_value=self._fill_value, **kwargs), + self.model, + self.name, + ) + swap_dims = varwrap(Dataset.swap_dims) set_index = varwrap(Dataset.set_index) diff --git a/test/conftest.py b/test/conftest.py index ee20cdc2..a2f61ca7 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,6 +3,8 @@ from __future__ import annotations import os +import warnings +from collections.abc import Generator from typing import TYPE_CHECKING import pandas as pd @@ -25,6 +27,12 @@ def pytest_addoption(parser: pytest.Parser) -> None: def pytest_configure(config: pytest.Config) -> None: """Configure pytest with custom markers and behavior.""" config.addinivalue_line("markers", "gpu: marks tests as requiring GPU hardware") + config.addinivalue_line( + "markers", "legacy_only: test runs only under the legacy arithmetic convention" + ) + config.addinivalue_line( + "markers", "v1_only: test runs 
only under the v1 arithmetic convention" + ) # Set environment variable so test modules can check if GPU tests are enabled # This is needed because parametrize happens at import time @@ -57,6 +65,38 @@ def pytest_collection_modifyitems( item.add_marker(pytest.mark.gpu) +@pytest.fixture(autouse=True, params=["legacy", "v1"]) +def convention(request: pytest.FixtureRequest) -> Generator[str, None, None]: + """ + Run every test under both arithmetic conventions by default. + + Tests marked ``@pytest.mark.legacy_only`` or ``@pytest.mark.v1_only`` + are automatically skipped for the other convention. + + Under "legacy", LinopyDeprecationWarning is suppressed so that the test + output stays clean. Dedicated tests in test_convention.py verify that + these warnings are actually emitted. + """ + import linopy + from linopy.config import LinopyDeprecationWarning + + item = request.node + if item.get_closest_marker("legacy_only") and request.param != "legacy": + pytest.skip("legacy-only test") + if item.get_closest_marker("v1_only") and request.param != "v1": + pytest.skip("v1-only test") + + old = linopy.options["arithmetic_convention"] + linopy.options["arithmetic_convention"] = request.param + if request.param == "legacy": + with warnings.catch_warnings(): + warnings.simplefilter("ignore", LinopyDeprecationWarning) + yield request.param + else: + yield request.param + linopy.options["arithmetic_convention"] = old + + @pytest.fixture def m() -> Model: from linopy import Model diff --git a/test/test_algebraic_properties.py b/test/test_algebraic_properties.py new file mode 100644 index 00000000..17f0fafa --- /dev/null +++ b/test/test_algebraic_properties.py @@ -0,0 +1,826 @@ +""" +Algebraic properties of linopy arithmetic. + +All standard algebraic laws should hold for linopy expressions, +including in the presence of absent slots (NaN from shift/where/reindex). + +This file serves as both specification and test suite. 
+ +Notation: + x[A], y[A], z[A] — linopy variables with dimension A + g[A,B] — linopy variable with dimensions A and B + c[B] — constant (DataArray) with dimension B + s — scalar (int/float) + xs — x.shift(time=1), variable with absent slot + +SPECIFICATION +============= + +1. Commutativity + a + b == b + a for any linopy operands a, b + a * c == c * a for variable/expression a, constant c + +2. Associativity + (a + b) + c == a + (b + c) for any linopy operands a, b, c + Including with absent slots: (xs + s) + y == xs + (s + y) + +3. Distributivity + c * (a + b) == c*a + c*b for constant c, linopy operands a, b + Including with absent slots: s * (xs + c) == s*xs + s*c + +4. Identity + a + 0 == a additive identity + a * 1 == a multiplicative identity + +5. Negation + a - b == a + (-b) subtraction is addition of negation + -(-a) == a double negation + +6. Zero + a * 0 == 0 multiplication by zero + +7. NaN / absent slot behavior + Addition uses additive identity (0) to fill NaN const: + xs + s revives absent slot with const=s + xs - s revives absent slot with const=-s + Multiplication propagates NaN: + xs * s keeps absent slot absent + xs / s keeps absent slot absent + Merge (expression + expression): + xs + y — absent x term doesn't poison valid y term + xs + ys — fully absent when ALL terms absent + Variable and expression paths are consistent. + +8. fillna + Variable.fillna(numeric) returns LinearExpression + Expression.fillna(value) fills const at absent slots + +9. Named methods with fill_value + .add(v, fill_value=f) fills const before adding + .mul(v, fill_value=f) fills const before multiplying + +10. Expression–expression algebraic laws + (2x+3) + (4y+1) == (4y+1) + (2x+3) commutativity + ((2x+1)+(3y+2))+(4z+3) == (2x+1)+((3y+2)+(4z+3)) associativity + +11. Division distributivity + (a + b) / c == a/c + b/c + +12. Subtraction distributivity + c * (a - b) == c*a - c*b + +13. Negative scalar distributivity + -s * (a + b) == -s*a + (-s*b) + +14. 
Multi-step constant folding + (x + 3) * 2 + 1 should equal 2*x + 7 + +15. Mixed-type commutativity + Variable + Expression == Expression + Variable +""" + +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from linopy import Model +from linopy.expressions import LinearExpression +from linopy.variables import Variable + + +@pytest.fixture +def m() -> Model: + return Model() + + +@pytest.fixture +def time() -> pd.RangeIndex: + return pd.RangeIndex(3, name="time") + + +@pytest.fixture +def tech() -> pd.Index: + return pd.Index(["solar", "wind"], name="tech") + + +@pytest.fixture +def x(m: Model, time: pd.RangeIndex) -> Variable: + """Variable with dims [time].""" + return m.add_variables(lower=0, coords=[time], name="x") + + +@pytest.fixture +def y(m: Model, time: pd.RangeIndex) -> Variable: + """Variable with dims [time].""" + return m.add_variables(lower=0, coords=[time], name="y") + + +@pytest.fixture +def z(m: Model, time: pd.RangeIndex) -> Variable: + """Variable with dims [time].""" + return m.add_variables(lower=0, coords=[time], name="z") + + +@pytest.fixture +def g(m: Model, time: pd.RangeIndex, tech: pd.Index) -> Variable: + """Variable with dims [time, tech].""" + return m.add_variables(lower=0, coords=[time, tech], name="g") + + +@pytest.fixture +def c(tech: pd.Index) -> xr.DataArray: + """Constant (DataArray) with dims [tech].""" + return xr.DataArray([2.0, 3.0], dims=["tech"], coords={"tech": tech}) + + +def assert_linequal(a: LinearExpression, b: LinearExpression) -> None: + """ + Assert two linear expressions are algebraically equivalent. + + Checks dimensions, coordinates, coefficients, variable references, and constants. 
+ """ + assert set(a.dims) == set(b.dims), f"dims differ: {a.dims} vs {b.dims}" + for dim in a.dims: + if isinstance(dim, str) and dim.startswith("_"): + continue + np.testing.assert_array_equal( + sorted(a.coords[dim].values), sorted(b.coords[dim].values) + ) + # Simplify both to canonical form for coefficient/variable comparison + a_s = a.simplify() + b_s = b.simplify() + assert a_s.nterm == b_s.nterm, f"nterm differs: {a_s.nterm} vs {b_s.nterm}" + np.testing.assert_array_almost_equal( + np.sort(a_s.coeffs.values, axis=None), + np.sort(b_s.coeffs.values, axis=None), + err_msg="coefficients differ", + ) + np.testing.assert_array_equal( + np.sort(a_s.vars.values, axis=None), + np.sort(b_s.vars.values, axis=None), + ) + np.testing.assert_array_almost_equal( + a.const.values, b.const.values, err_msg="constants differ" + ) + + +# ============================================================ +# 1. Commutativity +# ============================================================ + + +class TestCommutativity: + def test_add_var_var(self, x: Variable, y: Variable) -> None: + """X + y == y + x""" + assert_linequal(x + y, y + x) + + def test_mul_var_constant(self, g: Variable, c: xr.DataArray) -> None: + """G * c == c * g""" + assert_linequal(g * c, c * g) + + def test_add_var_constant(self, g: Variable, c: xr.DataArray) -> None: + """G + c == c + g""" + assert_linequal(g + c, c + g) + + def test_add_var_scalar(self, x: Variable) -> None: + """X + 5 == 5 + x""" + assert_linequal(x + 5, 5 + x) + + def test_mul_var_scalar(self, x: Variable) -> None: + """X * 3 == 3 * x""" + assert_linequal(x * 3, 3 * x) + + +# ============================================================ +# 2. 
Associativity +# ============================================================ + + +class TestAssociativity: + def test_add_same_dims(self, x: Variable, y: Variable, z: Variable) -> None: + """(x + y) + z == x + (y + z)""" + assert_linequal((x + y) + z, x + (y + z)) + + def test_add_with_constant(self, x: Variable, g: Variable, c: xr.DataArray) -> None: + """(x[A] + c[B]) + g[A,B] == x[A] + (c[B] + g[A,B])""" + assert_linequal((x + c) + g, x + (c + g)) + + def test_add_shifted_scalar_var(self, x: Variable, y: Variable) -> None: + """(x.shift(1) + 5) + y == x.shift(1) + (5 + y)""" + lhs = (x.shift(time=1) + 5) + y + rhs = x.shift(time=1) + (5 + y) + assert_linequal(lhs, rhs) + + def test_add_shifted_scalar_var_reordered(self, x: Variable, y: Variable) -> None: + """(x.shift(1) + y) + 5 == x.shift(1) + (y + 5)""" + lhs = (x.shift(time=1) + y) + 5 + rhs = x.shift(time=1) + (y + 5) + assert_linequal(lhs, rhs) + + def test_add_three_scalars_shifted(self, x: Variable) -> None: + """(x.shift(1) + 3) + 7 == x.shift(1) + 10""" + lhs = (x.shift(time=1) + 3) + 7 + rhs = x.shift(time=1) + 10 + assert_linequal(lhs, rhs) + + def test_sub_shifted_scalar_var(self, x: Variable, y: Variable) -> None: + """(x.shift(1) - 5) + y == x.shift(1) + (y - 5)""" + lhs = (x.shift(time=1) - 5) + y + rhs = x.shift(time=1) + (y - 5) + assert_linequal(lhs, rhs) + + +# ============================================================ +# 3. 
Distributivity +# ============================================================ + + +class TestDistributivity: + def test_scalar(self, x: Variable, y: Variable) -> None: + """S * (x + y) == s*x + s*y""" + assert_linequal(3 * (x + y), 3 * x + 3 * y) + + def test_constant_subset_dims(self, g: Variable, c: xr.DataArray) -> None: + """c[B] * (g[A,B] + g[A,B]) == c*g + c*g""" + assert_linequal(c * (g + g), c * g + c * g) + + def test_constant_mixed_dims( + self, x: Variable, g: Variable, c: xr.DataArray + ) -> None: + """c[B] * (x[A] + g[A,B]) == c*x + c*g""" + assert_linequal(c * (x + g), c * x + c * g) + + def test_scalar_shifted_add_constant(self, x: Variable) -> None: + """3 * (x.shift(1) + 5) == 3*x.shift(1) + 15""" + lhs = 3 * (x.shift(time=1) + 5) + rhs = 3 * x.shift(time=1) + 15 + assert_linequal(lhs, rhs) + + def test_scalar_shifted_add_var(self, x: Variable, y: Variable) -> None: + """3 * (x.shift(1) + y) == 3*x.shift(1) + 3*y""" + lhs = 3 * (x.shift(time=1) + y) + rhs = 3 * x.shift(time=1) + 3 * y + assert_linequal(lhs, rhs) + + +# ============================================================ +# 4. Identity +# ============================================================ + + +class TestIdentity: + def test_additive(self, x: Variable) -> None: + """X + 0 == x""" + result = x + 0 + assert isinstance(result, LinearExpression) + assert (result.const == 0).all() + np.testing.assert_array_equal(result.coeffs.squeeze().values, [1, 1, 1]) + + def test_multiplicative(self, x: Variable) -> None: + """X * 1 == x""" + result = x * 1 + assert isinstance(result, LinearExpression) + np.testing.assert_array_equal(result.coeffs.squeeze().values, [1, 1, 1]) + + def test_additive_shifted(self, x: Variable) -> None: + """x.shift(1) + 0 revives absent slot as zero expression.""" + result = x.shift(time=1) + 0 + assert not result.isnull().values[0] + assert result.const.values[0] == 0 + + +# ============================================================ +# 5. 
Negation +# ============================================================ + + +class TestNegation: + def test_subtraction_is_add_negation(self, x: Variable, y: Variable) -> None: + """X - y == x + (-y)""" + assert_linequal(x - y, x + (-y)) + + def test_subtraction_definition(self, x: Variable, y: Variable) -> None: + """X - y == x + (-1) * y""" + assert_linequal(x - y, x + (-1) * y) + + def test_double_negation(self, x: Variable) -> None: + """-(-x) has same coefficients as x""" + result = -(-x) + np.testing.assert_array_equal( + result.coeffs.squeeze().values, + (1 * x).coeffs.squeeze().values, + ) + + +# ============================================================ +# 6. Zero +# ============================================================ + + +class TestZero: + def test_multiplication_by_zero(self, x: Variable) -> None: + """X * 0 has zero coefficients""" + result = x * 0 + assert (result.coeffs == 0).all() + + +# ============================================================ +# 7. NaN / absent slot behavior +# ============================================================ + + +class TestAbsentSlotAddition: + """Addition fills const with 0 (additive identity) → revives absent slots.""" + + def test_add_scalar_revives(self, x: Variable) -> None: + result = x.shift(time=1) + 5 + assert not result.isnull().values[0] + assert result.const.values[0] == 5 + + def test_add_array_revives(self, x: Variable) -> None: + arr = xr.DataArray([10.0, 20.0, 30.0], dims=["time"]) + result = (1 * x).shift(time=1) + arr + assert not result.isnull().values[0] + assert result.const.values[0] == 10.0 + + def test_sub_scalar_revives(self, x: Variable) -> None: + result = x.shift(time=1) - 5 + assert not result.isnull().values[0] + assert result.const.values[0] == -5 + + def test_add_zero_revives(self, x: Variable) -> None: + """+ 0 revives to a zero expression (not absent).""" + result = x.shift(time=1) + 0 + assert not result.isnull().values[0] + assert result.const.values[0] == 0 + + def 
test_variable_and_expression_paths_consistent_add(self, x: Variable) -> None: + var_result = x.shift(time=1) + 5 + expr_result = (1 * x).shift(time=1) + 5 + np.testing.assert_array_equal( + var_result.isnull().values, expr_result.isnull().values + ) + np.testing.assert_array_equal(var_result.const.values, expr_result.const.values) + + +class TestAbsentSlotMultiplication: + """Multiplication propagates NaN → absent stays absent.""" + + @pytest.mark.v1_only + def test_mul_scalar_propagates(self, x: Variable) -> None: + result = x.shift(time=1) * 3 + assert result.isnull().values[0] + assert not result.isnull().values[1] + + @pytest.mark.v1_only + def test_mul_array_propagates(self, x: Variable) -> None: + arr = xr.DataArray([2.0, 2.0, 2.0], dims=["time"]) + result = (1 * x).shift(time=1) * arr + assert result.isnull().values[0] + + @pytest.mark.v1_only + def test_div_scalar_propagates(self, x: Variable) -> None: + result = (1 * x).shift(time=1) / 2 + assert result.isnull().values[0] + + def test_variable_and_expression_paths_consistent_mul(self, x: Variable) -> None: + var_result = x.shift(time=1) * 3 + expr_result = (1 * x).shift(time=1) * 3 + np.testing.assert_array_equal( + var_result.isnull().values, expr_result.isnull().values + ) + + +class TestAbsentSlotMerge: + """Merging expressions: absent terms don't poison valid terms.""" + + def test_partial_absent(self, x: Variable, y: Variable) -> None: + """X + y.shift(1): x is valid everywhere → no absent slots.""" + result = x + (1 * y).shift(time=1) + assert not result.isnull().any() + + def test_all_absent(self, x: Variable, y: Variable) -> None: + """x.shift(1) + y.shift(1): all terms absent at time=0 → absent.""" + result = (1 * x).shift(time=1) + (1 * y).shift(time=1) + assert result.isnull().values[0] + assert not result.isnull().values[1] + + def test_shifted_const_lost(self, x: Variable, y: Variable) -> None: + """X + (y+5).shift(1): shifted constant is lost at the gap.""" + result = x + (1 * y + 
5).shift(time=1) + # time=0: only x's const (0), shifted 5 is lost + assert result.const.values[0] == 0 + # time=1: both consts survive (0 + 5 = 5) + assert result.const.values[1] == 5 + + +class TestAbsentSlotMixed: + """Combined add/mul with absent slots.""" + + def test_add_then_mul(self, x: Variable) -> None: + """(x.shift(1) + 5) * 3 → +15 at absent slot.""" + result = (x.shift(time=1) + 5) * 3 + assert not result.isnull().values[0] + assert result.const.values[0] == 15 + + def test_mul_then_add(self, x: Variable) -> None: + """x.shift(1) * 3 + 5 → +5 at absent slot.""" + result = x.shift(time=1) * 3 + 5 + assert not result.isnull().values[0] + assert result.const.values[0] == 5 + + def test_where_add_revives(self, x: Variable) -> None: + mask = xr.DataArray([True, False, True], dims=["time"]) + result = (1 * x).where(mask) + 10 + assert not result.isnull().any() + assert result.const.values[1] == 10 + + @pytest.mark.v1_only + def test_where_mul_propagates(self, x: Variable) -> None: + mask = xr.DataArray([True, False, True], dims=["time"]) + result = (1 * x).where(mask) * 3 + assert not result.isnull().values[0] + assert result.isnull().values[1] + assert not result.isnull().values[2] + + +# ============================================================ +# 8. 
fillna +# ============================================================ + + +class TestFillNA: + """fillna revives absent slots with explicit values.""" + + def test_variable_fillna_numeric_returns_expression(self, x: Variable) -> None: + result = x.shift(time=1).fillna(0) + assert isinstance(result, LinearExpression) + + def test_variable_fillna_revives(self, x: Variable) -> None: + result = x.shift(time=1).fillna(0) + assert not result.isnull().any() + assert result.const.values[0] == 0 + + @pytest.mark.v1_only + def test_variable_fillna_custom_value(self, x: Variable) -> None: + result = x.shift(time=1).fillna(42) + assert result.const.values[0] == 42 + assert result.const.values[1] == 0 # valid slots unaffected + + def test_expression_fillna_revives(self, x: Variable) -> None: + result = (1 * x).shift(time=1).fillna(0) + 5 + assert not result.isnull().any() + assert result.const.values[0] == 5 + + def test_variable_fillna_variable_returns_variable( + self, x: Variable, y: Variable + ) -> None: + result = x.shift(time=1).fillna(y) + assert isinstance(result, Variable) + + def test_fillna_then_add_equals_fillna_sum(self, x: Variable) -> None: + """fillna(0) + 5 == fillna(5) at absent slots.""" + a = (1 * x).shift(time=1).fillna(0) + 5 + b = (1 * x).shift(time=1).fillna(5) + assert a.const.values[0] == 5 + assert b.const.values[0] == 5 + + +# ============================================================ +# 9. 
Named methods with fill_value +# ============================================================ + + +class TestFillValueParam: + """Named methods (.add, .sub, .mul, .div) accept fill_value.""" + + def test_add_fill_value(self, x: Variable) -> None: + expr = (1 * x).shift(time=1) + result = expr.add(5, fill_value=0) + assert not result.isnull().any() + assert result.const.values[0] == 5 + + def test_sub_fill_value(self, x: Variable) -> None: + expr = (1 * x).shift(time=1) + result = expr.sub(5, fill_value=0) + assert not result.isnull().any() + assert result.const.values[0] == -5 + + def test_mul_fill_value(self, x: Variable) -> None: + expr = (1 * x).shift(time=1) + result = expr.mul(3, fill_value=0) + assert not result.isnull().any() + assert result.const.values[0] == 0 + + def test_div_fill_value(self, x: Variable) -> None: + expr = (1 * x).shift(time=1) + result = expr.div(2, fill_value=0) + assert not result.isnull().any() + assert result.const.values[0] == 0 + + def test_add_without_fill_value_still_revives(self, x: Variable) -> None: + """add() always fills const with 0 (additive identity).""" + expr = (1 * x).shift(time=1) + result = expr.add(5) + assert not result.isnull().values[0] + assert result.const.values[0] == 5 + + def test_fill_value_only_affects_absent(self, x: Variable) -> None: + expr = (1 * x).shift(time=1) + result = expr.add(5, fill_value=0) + assert result.const.values[1] == 5 # valid slot: 0 + 5 + assert result.coeffs.values[1, 0] == 1 # coeff unchanged + + +# ============================================================ +# 10. 
Division with absent slots +# ============================================================ + + +class TestDivisionAbsentSlots: + """Division propagates NaN same as multiplication.""" + + @pytest.mark.v1_only + def test_div_scalar_propagates(self, x: Variable) -> None: + result = (1 * x).shift(time=1) / 2 + assert result.isnull().values[0] + assert not result.isnull().values[1] + + @pytest.mark.v1_only + def test_div_array_propagates(self, x: Variable) -> None: + arr = xr.DataArray([2.0, 2.0, 2.0], dims=["time"]) + result = (1 * x).shift(time=1) / arr + assert result.isnull().values[0] + + def test_div_consistent_paths(self, x: Variable) -> None: + """Variable and expression paths give same result for division.""" + var_result = x.shift(time=1) / 2 + expr_result = (1 * x).shift(time=1) / 2 + assert_linequal(var_result, expr_result) + + def test_div_equals_mul_reciprocal(self, x: Variable) -> None: + """Shifted / 2 == shifted * 0.5""" + shifted = (1 * x).shift(time=1) + assert_linequal(shifted / 2, shifted * 0.5) + + +# ============================================================ +# 11. Subtraction with absent slots +# ============================================================ + + +class TestSubtractionAbsentSlots: + """Subtraction with shifted coords preserves associativity.""" + + def test_sub_scalar_revives(self, x: Variable) -> None: + result = x.shift(time=1) - 5 + assert not result.isnull().values[0] + assert result.const.values[0] == -5 + + def test_sub_associativity_shifted(self, x: Variable, y: Variable) -> None: + """(x.shift(1) - 5) + y == x.shift(1) + (y - 5)""" + xs = x.shift(time=1) + assert_linequal((xs - 5) + y, xs + (y - 5)) + + def test_sub_equals_add_neg_shifted(self, x: Variable) -> None: + """x.shift(1) - 5 == x.shift(1) + (-5)""" + xs = x.shift(time=1) + assert_linequal(xs - 5, xs + (-5)) + + +# ============================================================ +# 12. 
Multi-dimensional absent slots +# ============================================================ + + +@pytest.fixture +def g2(m: Model, time: pd.RangeIndex, tech: pd.Index) -> Variable: + """Second variable with dims [time, tech].""" + return m.add_variables(lower=0, coords=[time, tech], name="g2") + + +class TestMultiDimensionalAbsentSlots: + """2D variables with shift: add revives, mul propagates.""" + + def test_2d_add_revives(self, g: Variable) -> None: + shifted = (1 * g).shift(time=1) + result = shifted + 5 + assert not result.isnull().isel(time=0).any() + assert (result.const.isel(time=0) == 5).all() + + @pytest.mark.v1_only + def test_2d_mul_propagates(self, g: Variable) -> None: + shifted = (1 * g).shift(time=1) + result = shifted * 3 + assert result.isnull().isel(time=0).all() + + def test_2d_associativity(self, g: Variable, g2: Variable) -> None: + """(g.shift(1) + g2) + 5 == g.shift(1) + (g2 + 5) in 2D.""" + gs = g.shift(time=1) + assert_linequal((gs + g2) + 5, gs + (g2 + 5)) + + def test_2d_distributivity(self, g: Variable, g2: Variable) -> None: + """2 * (g.shift(1) + g2) == 2*g.shift(1) + 2*g2 in 2D.""" + gs = g.shift(time=1) + assert_linequal(2 * (gs + g2), 2 * gs + 2 * g2) + + +# ============================================================ +# 13. 
Expression–expression algebraic laws +# ============================================================ + + +class TestExpressionExpressionAlgebra: + """Algebraic laws where both operands are multi-term expressions.""" + + def test_commutativity(self, x: Variable, y: Variable) -> None: + """(2x+3) + (4y+1) == (4y+1) + (2x+3)""" + a = 2 * x + 3 + b = 4 * y + 1 + assert_linequal(a + b, b + a) + + def test_associativity(self, x: Variable, y: Variable, z: Variable) -> None: + """((2x+1)+(3y+2))+(4z+3) == (2x+1)+((3y+2)+(4z+3))""" + a = 2 * x + 1 + b = 3 * y + 2 + c = 4 * z + 3 + assert_linequal((a + b) + c, a + (b + c)) + + def test_commutativity_shifted(self, x: Variable, y: Variable) -> None: + """Shifted expressions: (xs+5) + (2y+1) == (2y+1) + (xs+5)""" + a = x.shift(time=1) + 5 + b = 2 * y + 1 + assert_linequal(a + b, b + a) + + def test_associativity_shifted(self, x: Variable, y: Variable, z: Variable) -> None: + """Shifted: ((xs+5)+(2y))+(3z) == (xs+5)+((2y)+(3z))""" + a = x.shift(time=1) + 5 + b = 2 * y + c = 3 * z + assert_linequal((a + b) + c, a + (b + c)) + + def test_four_term_associativity( + self, x: Variable, y: Variable, z: Variable + ) -> None: + """((a+b)+c)+d == a+(b+(c+d)) with four expression operands.""" + a = 1 * x + 1 + b = 2 * y + 2 + c = 3 * z + 3 + d = 4 * x + 4 + assert_linequal(((a + b) + c) + d, a + (b + (c + d))) + + +# ============================================================ +# 14. 
Division distributivity +# ============================================================ + + +class TestDivisionDistributivity: + """(a + b) / c == a/c + b/c""" + + def test_scalar(self, x: Variable, y: Variable) -> None: + assert_linequal((x + y) / 2, x / 2 + y / 2) + + def test_array(self, g: Variable, g2: Variable, c: xr.DataArray) -> None: + assert_linequal((g + g2) / c, g / c + g2 / c) + + def test_with_constant_offset(self, x: Variable, y: Variable) -> None: + """(x + y + 6) / 2 == x/2 + y/2 + 3""" + assert_linequal((x + y + 6) / 2, x / 2 + y / 2 + 3) + + def test_shifted(self, x: Variable, y: Variable) -> None: + """(xs + y) / 2 == xs/2 + y/2""" + xs = x.shift(time=1) + assert_linequal((xs + y) / 2, xs / 2 + y / 2) + + +# ============================================================ +# 15. Subtraction distributivity +# ============================================================ + + +class TestSubtractionDistributivity: + """c * (a - b) == c*a - c*b""" + + def test_scalar(self, x: Variable, y: Variable) -> None: + assert_linequal(3 * (x - y), 3 * x - 3 * y) + + def test_array(self, g: Variable, g2: Variable, c: xr.DataArray) -> None: + assert_linequal(c * (g - g2), c * g - c * g2) + + def test_shifted(self, x: Variable, y: Variable) -> None: + xs = x.shift(time=1) + assert_linequal(3 * (xs - y), 3 * xs - 3 * y) + + def test_sub_then_div(self, x: Variable, y: Variable) -> None: + """(a - b) / c == a/c - b/c""" + assert_linequal((x - y) / 2, x / 2 - y / 2) + + +# ============================================================ +# 16. 
Negative scalar distributivity +# ============================================================ + + +class TestNegativeDistributivity: + """-s * (a + b) == -s*a + (-s*b)""" + + def test_negative_scalar(self, x: Variable, y: Variable) -> None: + assert_linequal(-3 * (x + y), -3 * x + (-3) * y) + + def test_negative_one(self, x: Variable, y: Variable) -> None: + """-(x + y) == -x + (-y)""" + assert_linequal(-(x + y), -x + (-y)) + + def test_negative_scalar_with_constant(self, x: Variable) -> None: + """-2 * (x + 5) == -2*x + (-10)""" + assert_linequal(-2 * (x + 5), -2 * x + (-10)) + + def test_negative_scalar_shifted(self, x: Variable, y: Variable) -> None: + xs = x.shift(time=1) + assert_linequal(-3 * (xs + y), -3 * xs + (-3) * y) + + def test_negate_expression(self, x: Variable, y: Variable) -> None: + """-(2x + 3y + 5) == -2x + (-3y) + (-5)""" + assert_linequal(-(2 * x + 3 * y + 5), -2 * x + (-3) * y + (-5)) + + +# ============================================================ +# 17. Multi-step constant folding +# ============================================================ + + +class TestMultiStepArithmetic: + """Chains of operations that combine constants through multiple steps.""" + + def test_add_then_mul_then_add(self, x: Variable) -> None: + """(x + 3) * 2 + 1 == 2*x + 7""" + assert_linequal((x + 3) * 2 + 1, 2 * x + 7) + + def test_mul_then_add_then_mul(self, x: Variable) -> None: + """(x * 2 + 1) * 3 == 6*x + 3""" + assert_linequal((x * 2 + 1) * 3, 6 * x + 3) + + def test_sub_then_mul(self, x: Variable) -> None: + """(x - 4) * 3 == 3*x - 12""" + assert_linequal((x - 4) * 3, 3 * x - 12) + + def test_div_then_add(self, x: Variable) -> None: + """(x + 6) / 2 + 1 == x/2 + 4""" + assert_linequal((x + 6) / 2 + 1, x / 2 + 4) + + def test_chain_three_ops(self, x: Variable) -> None: + """((x + 1) * 2 - 3) * 4 == 8*x + (-4)""" + assert_linequal(((x + 1) * 2 - 3) * 4, 8 * x + (-4)) + + def test_chain_with_two_vars(self, x: Variable, y: Variable) -> None: + """(2x + 3y + 
5) * 2 - 4 == 4x + 6y + 6""" + assert_linequal((2 * x + 3 * y + 5) * 2 - 4, 4 * x + 6 * y + 6) + + def test_chain_shifted(self, x: Variable) -> None: + """(xs + 3) * 2 + 1: at absent slot const should be 7.""" + result = (x.shift(time=1) + 3) * 2 + 1 + assert result.const.values[0] == 7 + + +# ============================================================ +# 18. Mixed-type commutativity (Variable + Expression) +# ============================================================ + + +class TestMixedTypeCommutativity: + """Variable + Expression == Expression + Variable and similar mixes.""" + + def test_var_plus_expr(self, x: Variable, y: Variable) -> None: + """X + (2*y + 3) == (2*y + 3) + x""" + expr = 2 * y + 3 + assert_linequal(x + expr, expr + x) + + def test_var_minus_expr(self, x: Variable, y: Variable) -> None: + """X - (2*y + 3) == x + (-(2*y + 3))""" + expr = 2 * y + 3 + assert_linequal(x - expr, x + (-expr)) + + def test_expr_plus_var_plus_expr( + self, x: Variable, y: Variable, z: Variable + ) -> None: + """(2x+1) + y + (3z+2) == y + (2x+1) + (3z+2)""" + a = 2 * x + 1 + b = 3 * z + 2 + assert_linequal((a + y) + b, (y + a) + b) + + def test_shifted_var_plus_expr(self, x: Variable, y: Variable) -> None: + """x.shift(1) + (2*y + 1) == (2*y + 1) + x.shift(1)""" + xs = x.shift(time=1) + expr = 2 * y + 1 + assert_linequal(xs + expr, expr + xs) + + def test_var_plus_multiterm_expr( + self, x: Variable, y: Variable, z: Variable + ) -> None: + """X + (2y + 3z + 5) == (2y + 3z + 5) + x""" + expr = 2 * y + 3 * z + 5 + assert_linequal(x + expr, expr + x) diff --git a/test/test_common.py b/test/test_common.py index f1190024..7fa70489 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -649,7 +649,9 @@ def test_get_dims_with_index_levels() -> None: assert get_dims_with_index_levels(ds5) == [] +@pytest.mark.legacy_only def test_align(x: Variable, u: Variable) -> None: # noqa: F811 + """Legacy: align() defaults to inner join for mismatched coords.""" alpha = 
xr.DataArray([1, 2], [[1, 2]]) beta = xr.DataArray( [1, 2, 3], @@ -663,7 +665,7 @@ def test_align(x: Variable, u: Variable) -> None: # noqa: F811 ], ) - # inner join + # inner join (default under legacy) x_obs, alpha_obs = align(x, alpha) assert isinstance(x_obs, Variable) assert x_obs.shape == alpha_obs.shape == (1,) @@ -692,6 +694,55 @@ def test_align(x: Variable, u: Variable) -> None: # noqa: F811 assert_linequal(expr_obs, expr.loc[[1]]) +@pytest.mark.v1_only +def test_align_v1(x: Variable, u: Variable) -> None: # noqa: F811 + """V1: align() defaults to exact join; explicit join= needed for mismatched coords.""" + alpha = xr.DataArray([1, 2], [[1, 2]]) + beta = xr.DataArray( + [1, 2, 3], + [ + ( + "dim_3", + pd.MultiIndex.from_tuples( + [(1, "b"), (2, "b"), (1, "c")], names=["level1", "level2"] + ), + ) + ], + ) + + # exact join raises on mismatched coords + with pytest.raises(ValueError, match="exact"): + align(x, alpha) + + # explicit inner join + x_obs, alpha_obs = align(x, alpha, join="inner") + assert isinstance(x_obs, Variable) + assert x_obs.shape == alpha_obs.shape == (1,) + assert_varequal(x_obs, x.loc[[1]]) + + # left-join + x_obs, alpha_obs = align(x, alpha, join="left") + assert x_obs.shape == alpha_obs.shape == (2,) + assert isinstance(x_obs, Variable) + assert_varequal(x_obs, x) + assert_equal(alpha_obs, DataArray([np.nan, 1], [[0, 1]])) + + # multiindex with explicit inner join + beta_obs, u_obs = align(beta, u, join="inner") + assert u_obs.shape == beta_obs.shape == (2,) + assert isinstance(u_obs, Variable) + assert_varequal(u_obs, u.loc[[(1, "b"), (2, "b")]]) + assert_equal(beta_obs, beta.loc[[(1, "b"), (2, "b")]]) + + # with linear expression, explicit inner join + expr = 20 * x + x_obs, expr_obs, alpha_obs = align(x, expr, alpha, join="inner") + assert x_obs.shape == alpha_obs.shape == (1,) + assert expr_obs.shape == (1, 1) # _term dim + assert isinstance(expr_obs, LinearExpression) + assert_linequal(expr_obs, expr.loc[[1]]) + + def 
test_is_constant() -> None: model = Model() index = pd.Index(range(10), name="t") diff --git a/test/test_constraints.py b/test/test_constraints.py index 9a467c8c..67cd592b 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -175,9 +175,11 @@ def test_constraint_rhs_lower_dim(rhs_factory: Any) -> None: pytest.param(lambda m: pd.DataFrame(np.ones((5, 3))), id="dataframe"), ], ) +@pytest.mark.legacy_only def test_constraint_rhs_higher_dim_constant_warns( rhs_factory: Any, caplog: Any ) -> None: + """Legacy: higher-dim constant RHS warns about dimensions.""" m = Model() x = m.add_variables(coords=[range(5)], name="x") @@ -186,8 +188,19 @@ def test_constraint_rhs_higher_dim_constant_warns( assert "dimensions" in caplog.text +@pytest.mark.v1_only +def test_constraint_rhs_higher_dim_constant_broadcasts_v1() -> None: + """V1: higher-dim constant RHS broadcasts (creates redundant constraints).""" + m = Model() + x = m.add_variables(coords=[range(5)], name="x") + rhs = xr.DataArray(np.ones((5, 3)), dims=["dim_0", "extra"]) + c = m.add_constraints(x >= rhs, name="broadcast_con") + assert "extra" in c.dims + + +@pytest.mark.legacy_only def test_constraint_rhs_higher_dim_dataarray_reindexes() -> None: - """DataArray RHS with extra dims reindexes to expression coords (no raise).""" + """Legacy: DataArray RHS with extra dims reindexes to expression coords.""" m = Model() x = m.add_variables(coords=[range(5)], name="x") rhs = xr.DataArray(np.ones((5, 3)), dims=["dim_0", "extra"]) @@ -347,6 +360,8 @@ def test_sanitize_infinities() -> None: class TestConstraintCoordinateAlignment: + """Tests for constraint behavior when variable and RHS coordinates differ.""" + @pytest.fixture(params=["xarray", "pandas_series"], ids=["da", "series"]) def subset(self, request: Any) -> xr.DataArray | pd.Series: if request.param == "xarray": @@ -365,15 +380,36 @@ def superset(self, request: Any) -> xr.DataArray | pd.Series: np.arange(25, dtype=float), index=pd.Index(range(25), 
name="dim_2") ) - def test_var_le_subset(self, v: Variable, subset: xr.DataArray) -> None: + # -- var <= subset -- + + @pytest.mark.legacy_only + def test_var_le_subset_fills_nan(self, v: Variable, subset: xr.DataArray) -> None: con = v <= subset assert con.sizes["dim_2"] == v.sizes["dim_2"] assert con.rhs.sel(dim_2=1).item() == 10.0 assert con.rhs.sel(dim_2=3).item() == 30.0 assert np.isnan(con.rhs.sel(dim_2=0).item()) + @pytest.mark.v1_only + def test_var_le_subset_raises(self, v: Variable) -> None: + subset = xr.DataArray([10.0, 30.0], dims=["dim_2"], coords={"dim_2": [1, 3]}) + with pytest.raises(ValueError, match="exact"): + v <= subset + + @pytest.mark.v1_only + def test_var_le_subset_join_left(self, v: Variable) -> None: + subset = xr.DataArray([10.0, 30.0], dims=["dim_2"], coords={"dim_2": [1, 3]}) + con = v.to_linexpr().le(subset, join="left") + assert con.sizes["dim_2"] == v.sizes["dim_2"] + assert con.rhs.sel(dim_2=1).item() == 10.0 + assert con.rhs.sel(dim_2=3).item() == 30.0 + assert np.isnan(con.rhs.sel(dim_2=0).item()) + + # -- var comparison (all signs) with subset -- + + @pytest.mark.legacy_only @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL, EQUAL]) - def test_var_comparison_subset( + def test_var_comparison_subset_fills_nan( self, v: Variable, subset: xr.DataArray, sign: str ) -> None: if sign == LESS_EQUAL: @@ -386,7 +422,49 @@ def test_var_comparison_subset( assert con.rhs.sel(dim_2=1).item() == 10.0 assert np.isnan(con.rhs.sel(dim_2=0).item()) - def test_expr_le_subset(self, v: Variable, subset: xr.DataArray) -> None: + @pytest.mark.v1_only + @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL, EQUAL]) + def test_var_comparison_subset_raises(self, v: Variable, sign: str) -> None: + subset = xr.DataArray([10.0, 30.0], dims=["dim_2"], coords={"dim_2": [1, 3]}) + with pytest.raises(ValueError, match="exact"): + if sign == LESS_EQUAL: + v <= subset + elif sign == GREATER_EQUAL: + v >= subset + else: + v == subset + + 
@pytest.mark.v1_only + @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL, EQUAL]) + def test_var_comparison_subset_join_left(self, v: Variable, sign: str) -> None: + subset = xr.DataArray([10.0, 30.0], dims=["dim_2"], coords={"dim_2": [1, 3]}) + expr = v.to_linexpr() + if sign == LESS_EQUAL: + con = expr.le(subset, join="left") + elif sign == GREATER_EQUAL: + con = expr.ge(subset, join="left") + else: + con = expr.eq(subset, join="left") + assert con.sizes["dim_2"] == v.sizes["dim_2"] + assert con.rhs.sel(dim_2=1).item() == 10.0 + assert np.isnan(con.rhs.sel(dim_2=0).item()) + + @pytest.mark.v1_only + def test_var_comparison_subset_assign_coords(self, v: Variable) -> None: + """V1 pattern: use assign_coords to align before comparing.""" + target_coords = v.coords["dim_2"][:2] + subset = xr.DataArray( + [10.0, 30.0], dims=["dim_2"], coords={"dim_2": target_coords} + ) + con = v.loc[:1] <= subset + assert con.sizes["dim_2"] == 2 + assert con.rhs.sel(dim_2=0).item() == 10.0 + assert con.rhs.sel(dim_2=1).item() == 30.0 + + # -- expr <= subset -- + + @pytest.mark.legacy_only + def test_expr_le_subset_fills_nan(self, v: Variable, subset: xr.DataArray) -> None: expr = v + 5 con = expr <= subset assert con.sizes["dim_2"] == v.sizes["dim_2"] @@ -394,8 +472,28 @@ def test_expr_le_subset(self, v: Variable, subset: xr.DataArray) -> None: assert con.rhs.sel(dim_2=3).item() == pytest.approx(25.0) assert np.isnan(con.rhs.sel(dim_2=0).item()) + @pytest.mark.v1_only + def test_expr_le_subset_raises(self, v: Variable) -> None: + subset = xr.DataArray([10.0, 30.0], dims=["dim_2"], coords={"dim_2": [1, 3]}) + expr = v + 5 + with pytest.raises(ValueError, match="exact"): + expr <= subset + + @pytest.mark.v1_only + def test_expr_le_subset_join_left(self, v: Variable) -> None: + subset = xr.DataArray([10.0, 30.0], dims=["dim_2"], coords={"dim_2": [1, 3]}) + expr = v.to_linexpr() + 5 + con = expr.le(subset, join="left") + assert con.sizes["dim_2"] == v.sizes["dim_2"] + assert 
con.rhs.sel(dim_2=1).item() == pytest.approx(5.0) + assert con.rhs.sel(dim_2=3).item() == pytest.approx(25.0) + assert np.isnan(con.rhs.sel(dim_2=0).item()) + + # -- subset comparison var (reverse) -- + + @pytest.mark.legacy_only @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL, EQUAL]) - def test_subset_comparison_var( + def test_subset_comparison_var_fills_nan( self, v: Variable, subset: xr.DataArray, sign: str ) -> None: if sign == LESS_EQUAL: @@ -408,8 +506,23 @@ def test_subset_comparison_var( assert np.isnan(con.rhs.sel(dim_2=0).item()) assert con.rhs.sel(dim_2=1).item() == pytest.approx(10.0) + @pytest.mark.v1_only + @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL, EQUAL]) + def test_subset_comparison_var_raises(self, v: Variable, sign: str) -> None: + subset = xr.DataArray([10.0, 30.0], dims=["dim_2"], coords={"dim_2": [1, 3]}) + with pytest.raises(ValueError, match="exact"): + if sign == LESS_EQUAL: + subset <= v + elif sign == GREATER_EQUAL: + subset >= v + else: + subset == v + + # -- superset comparison var -- + + @pytest.mark.legacy_only @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL]) - def test_superset_comparison_var( + def test_superset_comparison_no_nan( self, v: Variable, superset: xr.DataArray, sign: str ) -> None: if sign == LESS_EQUAL: @@ -420,7 +533,31 @@ def test_superset_comparison_var( assert not np.isnan(con.lhs.coeffs.values).any() assert not np.isnan(con.rhs.values).any() - def test_constraint_rhs_extra_dims_broadcasts(self, v: Variable) -> None: + @pytest.mark.v1_only + @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL]) + def test_superset_comparison_var_raises(self, v: Variable, sign: str) -> None: + superset = xr.DataArray( + np.arange(25, dtype=float), dims=["dim_2"], coords={"dim_2": range(25)} + ) + with pytest.raises(ValueError, match="exact"): + if sign == LESS_EQUAL: + superset <= v + else: + superset >= v + + @pytest.mark.v1_only + def test_superset_comparison_join_inner(self, v: 
Variable) -> None: + superset = xr.DataArray( + np.arange(25, dtype=float), dims=["dim_2"], coords={"dim_2": range(25)} + ) + con = v.to_linexpr().le(superset, join="inner") + assert con.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(con.rhs.values).any() + + # -- extra dims -- + + @pytest.mark.legacy_only + def test_rhs_extra_dims_broadcasts(self, v: Variable) -> None: rhs = xr.DataArray( [[1.0, 2.0]], dims=["extra", "dim_2"], @@ -429,7 +566,52 @@ def test_constraint_rhs_extra_dims_broadcasts(self, v: Variable) -> None: c = v <= rhs assert "extra" in c.dims - def test_subset_constraint_solve_integration(self) -> None: + @pytest.mark.v1_only + def test_rhs_extra_dims_matching_broadcasts(self, v: Variable) -> None: + rhs = xr.DataArray( + np.ones((2, 20)), dims=["extra", "dim_2"], coords={"dim_2": range(20)} + ) + c = v <= rhs + assert "extra" in c.dims + + @pytest.mark.v1_only + def test_rhs_extra_dims_mismatched_raises(self, v: Variable) -> None: + rhs = xr.DataArray( + [[1.0, 2.0]], dims=["extra", "dim_2"], coords={"dim_2": [0, 1]} + ) + with pytest.raises(ValueError, match="exact"): + v <= rhs + + @pytest.mark.v1_only + def test_rhs_higher_dim_dataarray_matching_broadcasts(self) -> None: + """V1: DataArray RHS with extra dims broadcasts if shared dim coords match.""" + m = Model() + x = m.add_variables(coords=[range(5)], name="x") + rhs = xr.DataArray( + np.ones((5, 3)), + dims=["dim_0", "extra"], + coords={"dim_0": range(5)}, + ) + c = m.add_constraints(x >= rhs) + assert c.shape == (5, 3) + + @pytest.mark.v1_only + def test_rhs_higher_dim_dataarray_mismatched_raises(self) -> None: + """V1: DataArray RHS with mismatched shared dim coords raises.""" + m = Model() + x = m.add_variables(coords=[range(5)], name="x") + rhs = xr.DataArray( + np.ones((3, 3)), + dims=["dim_0", "extra"], + coords={"dim_0": [10, 11, 12]}, + ) + with pytest.raises(ValueError, match="exact"): + m.add_constraints(x >= rhs) + + # -- solver integration -- + + 
@pytest.mark.legacy_only + def test_subset_constraint_solve_implicit(self) -> None: if not available_solvers: pytest.skip("No solver available") solver = "highs" if "highs" in available_solvers else available_solvers[0] @@ -446,3 +628,23 @@ def test_subset_constraint_solve_integration(self) -> None: assert sol.sel(i=0).item() == pytest.approx(100.0) assert sol.sel(i=2).item() == pytest.approx(100.0) assert sol.sel(i=4).item() == pytest.approx(100.0) + + @pytest.mark.v1_only + def test_subset_constraint_solve_explicit_join(self) -> None: + if not available_solvers: + pytest.skip("No solver available") + solver = "highs" if "highs" in available_solvers else available_solvers[0] + m = Model() + coords = pd.RangeIndex(5, name="i") + x = m.add_variables(lower=0, upper=100, coords=[coords], name="x") + subset_ub = xr.DataArray([10.0, 20.0], dims=["i"], coords={"i": [1, 3]}) + # exact default raises — use explicit join="left" (NaN = no constraint) + m.add_constraints(x.to_linexpr().le(subset_ub, join="left"), name="subset_ub") + m.add_objective(x.sum(), sense="max") + m.solve(solver_name=solver) + sol = m.solution["x"] + assert sol.sel(i=1).item() == pytest.approx(10.0) + assert sol.sel(i=3).item() == pytest.approx(20.0) + assert sol.sel(i=0).item() == pytest.approx(100.0) + assert sol.sel(i=2).item() == pytest.approx(100.0) + assert sol.sel(i=4).item() == pytest.approx(100.0) diff --git a/test/test_convention.py b/test/test_convention.py new file mode 100644 index 00000000..e1099e81 --- /dev/null +++ b/test/test_convention.py @@ -0,0 +1,484 @@ +""" +Tests for the arithmetic convention system. 
+ +Covers: +- Config validation (valid/invalid convention values, default) +- Deprecation warnings under legacy convention +- Scalar fast path consistency +- NaN edge cases (inf, -inf) +- Convention switching mid-session +- Variable.reindex() and Variable.reindex_like() +""" + +from __future__ import annotations + +import warnings + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +import linopy +from linopy import LinearExpression, Model, Variable +from linopy.config import ( + LinopyDeprecationWarning, + OptionSettings, + options, +) +from linopy.constraints import Constraint +from linopy.testing import assert_linequal + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def m() -> Model: + model = Model() + model.add_variables(coords=[pd.RangeIndex(5, name="i")], name="a") + model.add_variables(coords=[pd.RangeIndex(5, name="i")], name="b") + return model + + +@pytest.fixture +def a(m: Model) -> Variable: + return m.variables["a"] + + +@pytest.fixture +def b(m: Model) -> Variable: + return m.variables["b"] + + +# --------------------------------------------------------------------------- +# 3. 
Config validation +# --------------------------------------------------------------------------- + + +class TestConfigValidation: + def test_default_convention_is_legacy(self) -> None: + """Default arithmetic_convention should be 'legacy'.""" + fresh = OptionSettings( + display_max_rows=14, + display_max_terms=6, + arithmetic_convention="legacy", + ) + assert fresh["arithmetic_convention"] == "legacy" + + def test_set_valid_convention_v1(self) -> None: + old = options["arithmetic_convention"] + try: + options["arithmetic_convention"] = "v1" + assert options["arithmetic_convention"] == "v1" + finally: + options["arithmetic_convention"] = old + + def test_set_valid_convention_legacy(self) -> None: + old = options["arithmetic_convention"] + try: + options["arithmetic_convention"] = "legacy" + assert options["arithmetic_convention"] == "legacy" + finally: + options["arithmetic_convention"] = old + + def test_set_invalid_convention_raises(self) -> None: + with pytest.raises(ValueError, match="Invalid arithmetic_convention"): + options["arithmetic_convention"] = "invalid" + + def test_set_invalid_convention_exact_raises(self) -> None: + """'exact' is a join mode, not a valid convention name.""" + with pytest.raises(ValueError, match="Invalid arithmetic_convention"): + options["arithmetic_convention"] = "exact" + + def test_invalid_key_raises(self) -> None: + with pytest.raises(KeyError, match="not a valid setting"): + options["nonexistent_key"] = 42 + + def test_get_invalid_key_raises(self) -> None: + with pytest.raises(KeyError, match="not a valid setting"): + _ = options["nonexistent_key"] + + +# --------------------------------------------------------------------------- +# 5. 
Deprecation warnings +# --------------------------------------------------------------------------- + + +@pytest.mark.legacy_only +class TestDeprecationWarnings: + def test_add_constant_emits_deprecation_warning(self, a: Variable) -> None: + const = xr.DataArray([1, 2, 3, 4, 5], dims=["i"], coords={"i": range(5)}) + with pytest.warns(LinopyDeprecationWarning, match="legacy"): + _ = (1 * a) + const + + def test_mul_constant_emits_deprecation_warning(self, a: Variable) -> None: + const = xr.DataArray([1, 2, 3, 4, 5], dims=["i"], coords={"i": range(5)}) + with pytest.warns(LinopyDeprecationWarning, match="legacy"): + _ = (1 * a) * const + + def test_align_emits_deprecation_warning(self, a: Variable) -> None: + alpha = xr.DataArray([1, 2], [[1, 2]]) + with pytest.warns(LinopyDeprecationWarning, match="legacy"): + linopy.align(a, alpha) + + +@pytest.mark.v1_only +class TestNoDeprecationWarnings: + """V1: matching-coord operations should not emit deprecation warnings.""" + + def test_add_constant_no_deprecation_warning(self, a: Variable) -> None: + const = xr.DataArray([1, 2, 3, 4, 5], dims=["i"], coords={"i": range(5)}) + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("error", LinopyDeprecationWarning) + _ = (1 * a) + const + + def test_mul_constant_no_deprecation_warning(self, a: Variable) -> None: + const = xr.DataArray([1, 2, 3, 4, 5], dims=["i"], coords={"i": range(5)}) + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("error", LinopyDeprecationWarning) + _ = (1 * a) * const + + def test_align_no_deprecation_warning(self, a: Variable) -> None: + alpha = xr.DataArray([1, 2, 3, 4, 5], dims=["i"], coords={"i": range(5)}) + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("error", LinopyDeprecationWarning) + linopy.align(a, alpha) + + +# --------------------------------------------------------------------------- +# 6. 
Scalar fast path +# --------------------------------------------------------------------------- + + +class TestScalarFastPath: + """Scalar operations should produce same results as array operations.""" + + def test_add_scalar_matches_array(self, a: Variable) -> None: + scalar_result = (1 * a) + 5 + array_const = xr.DataArray(np.full(5, 5.0), dims=["i"], coords={"i": range(5)}) + array_result = (1 * a) + array_const + assert_linequal(scalar_result, array_result) + + def test_sub_scalar_matches_array(self, a: Variable) -> None: + scalar_result = (1 * a) - 3 + array_const = xr.DataArray(np.full(5, 3.0), dims=["i"], coords={"i": range(5)}) + array_result = (1 * a) - array_const + assert_linequal(scalar_result, array_result) + + def test_mul_scalar_matches_array(self, a: Variable) -> None: + scalar_result = (1 * a) * 2 + array_const = xr.DataArray(np.full(5, 2.0), dims=["i"], coords={"i": range(5)}) + array_result = (1 * a) * array_const + assert_linequal(scalar_result, array_result) + + def test_div_scalar_matches_array(self, a: Variable) -> None: + scalar_result = (1 * a) / 4 + array_const = xr.DataArray(np.full(5, 4.0), dims=["i"], coords={"i": range(5)}) + array_result = (1 * a) / array_const + assert_linequal(scalar_result, array_result) + + +# --------------------------------------------------------------------------- +# 7. 
NaN edge cases +# --------------------------------------------------------------------------- + + +@pytest.mark.v1_only +class TestNaNEdgeCases: + def test_inf_add_propagates(self, a: Variable) -> None: + """Adding inf should propagate to const.""" + const = xr.DataArray( + [1.0, np.inf, 3.0, 4.0, 5.0], dims=["i"], coords={"i": range(5)} + ) + result = (1 * a) + const + assert np.isinf(result.const.values[1]) + + def test_neg_inf_add_propagates(self, a: Variable) -> None: + """Adding -inf should propagate to const.""" + const = xr.DataArray( + [1.0, -np.inf, 3.0, 4.0, 5.0], dims=["i"], coords={"i": range(5)} + ) + result = (1 * a) + const + assert np.isinf(result.const.values[1]) + assert result.const.values[1] < 0 + + def test_inf_mul_propagates(self, a: Variable) -> None: + """Multiplying by inf should propagate to coeffs.""" + const = xr.DataArray( + [1.0, np.inf, 3.0, 4.0, 5.0], dims=["i"], coords={"i": range(5)} + ) + result = (1 * a) * const + assert np.isinf(result.coeffs.squeeze().values[1]) + + def test_nan_mul_raises_v1(self, a: Variable) -> None: + """Under v1, NaN in mul should raise ValueError.""" + const = xr.DataArray( + [1.0, np.nan, 3.0, 4.0, 5.0], dims=["i"], coords={"i": range(5)} + ) + with pytest.raises(ValueError, match="NaN"): + (1 * a) * const + + +# --------------------------------------------------------------------------- +# 8. 
Convention switching mid-session +# --------------------------------------------------------------------------- + + +class TestConventionSwitching: + def test_switch_convention_mid_session(self, a: Variable, b: Variable) -> None: + """Switching convention mid-session should change behavior immediately.""" + const = xr.DataArray([1, 2, 3], dims=["i"], coords={"i": [0, 1, 2]}) + + # Under legacy: mismatched-size const should work + linopy.options["arithmetic_convention"] = "legacy" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", LinopyDeprecationWarning) + # This should succeed under legacy (left join / override) + _ = (1 * a) + const + + # Switch to v1: same operation with mismatched coords should raise + linopy.options["arithmetic_convention"] = "v1" + with pytest.raises(ValueError, match="exact"): + _ = (1 * a) + const + + def test_reset_restores_defaults(self) -> None: + """OptionSettings.reset() should restore factory defaults.""" + options["arithmetic_convention"] = "v1" + assert options["arithmetic_convention"] == "v1" + options.reset() + assert options["arithmetic_convention"] == "legacy" # factory default + + +# --------------------------------------------------------------------------- +# 9. TestJoinParameter deduplication (shared base class) +# --------------------------------------------------------------------------- +# The existing TestJoinParameter class already tests both conventions via +# @pytest.mark.legacy_only / @pytest.mark.v1_only markers. The deduplication +# is addressed by verifying that explicit join= works identically under both. 
+ + +class TestJoinWorksUnderBothConventions: + """Explicit join= should produce same results regardless of convention.""" + + @pytest.fixture + def m2(self) -> Model: + m = Model() + m.add_variables(coords=[pd.Index([0, 1, 2], name="i")], name="a") + m.add_variables(coords=[pd.Index([1, 2, 3], name="i")], name="b") + return m + + def test_add_inner_same_under_both(self, m2: Model) -> None: + a = m2.variables["a"] + b = m2.variables["b"] + + linopy.options["arithmetic_convention"] = "legacy" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", LinopyDeprecationWarning) + result_legacy = a.to_linexpr().add(b.to_linexpr(), join="inner") + + linopy.options["arithmetic_convention"] = "v1" + result_v1 = a.to_linexpr().add(b.to_linexpr(), join="inner") + + assert list(result_legacy.data.indexes["i"]) == list( + result_v1.data.indexes["i"] + ) + + def test_add_outer_same_under_both(self, m2: Model) -> None: + a = m2.variables["a"] + b = m2.variables["b"] + + linopy.options["arithmetic_convention"] = "legacy" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", LinopyDeprecationWarning) + result_legacy = a.to_linexpr().add(b.to_linexpr(), join="outer") + + linopy.options["arithmetic_convention"] = "v1" + result_v1 = a.to_linexpr().add(b.to_linexpr(), join="outer") + + assert set(result_legacy.data.indexes["i"]) == set(result_v1.data.indexes["i"]) + + +# --------------------------------------------------------------------------- +# 10. 
Error message tests +# --------------------------------------------------------------------------- + + +@pytest.mark.v1_only +class TestErrorMessages: + def test_exact_join_error_suggests_escape_hatches(self, a: Variable) -> None: + """Error message should suggest .add()/.mul() with join= parameter.""" + subset = xr.DataArray([1, 2, 3], dims=["i"], coords={"i": [0, 1, 2]}) + with pytest.raises(ValueError, match=r"\.add\(other, join="): + _ = (1 * a) + subset + + def test_exact_join_error_mentions_inner(self, a: Variable) -> None: + subset = xr.DataArray([1, 2, 3], dims=["i"], coords={"i": [0, 1, 2]}) + with pytest.raises(ValueError, match="inner"): + _ = (1 * a) + subset + + def test_exact_join_error_mentions_outer(self, a: Variable) -> None: + subset = xr.DataArray([1, 2, 3], dims=["i"], coords={"i": [0, 1, 2]}) + with pytest.raises(ValueError, match="outer"): + _ = (1 * a) + subset + + +# --------------------------------------------------------------------------- +# Variable.reindex() and Variable.reindex_like() +# --------------------------------------------------------------------------- + + +class TestVariableReindex: + @pytest.fixture + def var(self) -> Variable: + m = Model() + return m.add_variables(coords=[pd.Index([0, 1, 2, 3, 4], name="i")], name="v") + + def test_reindex_subset(self, var: Variable) -> None: + result = var.reindex(i=[1, 2, 3]) + assert isinstance(result, Variable) + assert list(result.data.indexes["i"]) == [1, 2, 3] + # Labels for the reindexed positions should be valid + assert (result.labels.sel(i=[1, 2, 3]).values >= 0).all() + + def test_reindex_superset(self, var: Variable) -> None: + result = var.reindex(i=[0, 1, 2, 3, 4, 5, 6]) + assert isinstance(result, Variable) + assert list(result.data.indexes["i"]) == [0, 1, 2, 3, 4, 5, 6] + # New positions should have sentinel label (-1) + assert result.labels.sel(i=5).item() == -1 + assert result.labels.sel(i=6).item() == -1 + # Original positions should be valid + assert 
(result.labels.sel(i=[0, 1, 2, 3, 4]).values >= 0).all() + + def test_reindex_preserves_type(self, var: Variable) -> None: + result = var.reindex(i=[0, 1]) + assert type(result) is type(var) + + def test_reindex_like_variable(self, var: Variable) -> None: + m = var.model + other = m.add_variables(coords=[pd.Index([2, 3, 4, 5], name="i")], name="other") + result = var.reindex_like(other) + assert isinstance(result, Variable) + assert list(result.data.indexes["i"]) == [2, 3, 4, 5] + # Position 5 should have sentinel + assert result.labels.sel(i=5).item() == -1 + # Positions 2,3,4 should be valid + assert (result.labels.sel(i=[2, 3, 4]).values >= 0).all() + + def test_reindex_like_dataarray(self, var: Variable) -> None: + other = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [1, 3, 5]}) + result = var.reindex_like(other) + assert isinstance(result, Variable) + assert list(result.data.indexes["i"]) == [1, 3, 5] + assert result.labels.sel(i=5).item() == -1 + + def test_reindex_empty(self, var: Variable) -> None: + result = var.reindex(i=[]) + assert isinstance(result, Variable) + assert len(result.data.indexes["i"]) == 0 + + +class TestExpressionReindex: + @pytest.fixture + def expr(self) -> LinearExpression: + m = Model() + x = m.add_variables(coords=[pd.Index([0, 1, 2, 3, 4], name="i")], name="x") + return 2 * x + 10 + + def test_reindex_subset(self, expr: LinearExpression) -> None: + result = expr.reindex(i=[1, 2, 3]) + assert isinstance(result, LinearExpression) + assert list(result.data.indexes["i"]) == [1, 2, 3] + # Coefficients for existing positions should be preserved + np.testing.assert_array_equal(result.coeffs.squeeze().values, [2, 2, 2]) + np.testing.assert_array_equal(result.const.values, [10, 10, 10]) + + def test_reindex_superset(self, expr: LinearExpression) -> None: + result = expr.reindex(i=[0, 1, 2, 3, 4, 5, 6]) + assert isinstance(result, LinearExpression) + assert list(result.data.indexes["i"]) == [0, 1, 2, 3, 4, 5, 6] + # New positions 
should have sentinel var labels (-1) + assert result.vars.squeeze().sel(i=5).item() == -1 + assert result.vars.squeeze().sel(i=6).item() == -1 + # Original positions should be valid + assert (result.vars.squeeze().sel(i=[0, 1, 2, 3, 4]).values >= 0).all() + + def test_reindex_fill_value(self, expr: LinearExpression) -> None: + result = expr.reindex(i=[0, 1, 5], fill_value=0) + assert result.const.sel(i=5).item() == 0 + result_nan = expr.reindex(i=[0, 1, 5]) + assert np.isnan(result_nan.const.sel(i=5).item()) + + def test_reindex_preserves_type(self, expr: LinearExpression) -> None: + result = expr.reindex(i=[0, 1]) + assert type(result) is type(expr) + + def test_reindex_like_expression(self, expr: LinearExpression) -> None: + m = expr.model + y = m.add_variables(coords=[pd.Index([2, 3, 4, 5], name="i")], name="y") + other = 1 * y + result = expr.reindex_like(other) + assert isinstance(result, LinearExpression) + assert list(result.data.indexes["i"]) == [2, 3, 4, 5] + assert result.vars.squeeze().sel(i=5).item() == -1 + + def test_reindex_like_variable(self, expr: LinearExpression) -> None: + m = expr.model + y = m.add_variables(coords=[pd.Index([1, 3, 5], name="i")], name="y") + result = expr.reindex_like(y) + assert isinstance(result, LinearExpression) + assert list(result.data.indexes["i"]) == [1, 3, 5] + + def test_reindex_like_dataarray(self, expr: LinearExpression) -> None: + da = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [1, 3, 5]}) + result = expr.reindex_like(da) + assert isinstance(result, LinearExpression) + assert list(result.data.indexes["i"]) == [1, 3, 5] + assert result.vars.squeeze().sel(i=5).item() == -1 + + def test_reindex_like_dataset(self, expr: LinearExpression) -> None: + ds = xr.Dataset({"tmp": (("i",), [1, 2])}, coords={"i": [0, 1]}) + result = expr.reindex_like(ds) + assert isinstance(result, LinearExpression) + assert list(result.data.indexes["i"]) == [0, 1] + + +class TestConstraintReindex: + @pytest.fixture + def con(self) -> 
Constraint: + m = Model() + x = m.add_variables(coords=[pd.Index([0, 1, 2, 3, 4], name="i")], name="x") + linopy.options["arithmetic_convention"] = "legacy" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", LinopyDeprecationWarning) + c = x >= 0 + m.add_constraints(c, name="c") + return m.constraints["c"] + + def test_reindex_subset(self, con: Constraint) -> None: + result = con.reindex({"i": [1, 2, 3]}) + assert list(result.data.indexes["i"]) == [1, 2, 3] + + def test_reindex_superset(self, con: Constraint) -> None: + result = con.reindex({"i": [0, 1, 2, 3, 4, 5]}) + assert list(result.data.indexes["i"]) == [0, 1, 2, 3, 4, 5] + # New position should have sentinel label + assert result.data.vars.squeeze().sel(i=5).item() == -1 + + def test_reindex_like_dataset(self, con: Constraint) -> None: + ds = xr.Dataset({"tmp": (("i",), [1, 2])}, coords={"i": [0, 1]}) + result = con.reindex_like(ds) + assert list(result.data.indexes["i"]) == [0, 1] + + def test_reindex_like_dataarray(self, con: Constraint) -> None: + da = xr.DataArray([10, 20], dims=["i"], coords={"i": [1, 3]}) + result = con.reindex_like(da) + assert list(result.data.indexes["i"]) == [1, 3] diff --git a/test/test_legacy_violations.py b/test/test_legacy_violations.py new file mode 100644 index 00000000..74c20e88 --- /dev/null +++ b/test/test_legacy_violations.py @@ -0,0 +1,452 @@ +""" +Legacy convention violations. + +Documents concrete bugs and surprising behaviors in the legacy arithmetic +convention. Each test class corresponds to a reported issue or PR and +contains paired legacy_only / v1_only tests showing: + +- **legacy_only** — what legacy actually does (the wrong/surprising behavior) +- **v1_only** — what v1 does instead (correct behavior, usually ValueError) + +This file serves as a living catalog of *why* the v1 convention exists. 
+ +Related issues / PRs +==================== + +Positional alignment (override join): + #586 — Constraint RHS matched by position, not label + #550 — Silent data corruption with reordered coordinates + #257 — .loc[] reorder undone by override + +Subset / superset alignment (left join): + #572 — Non-associative arithmetic with constants + #569 — Variable vs Expression inconsistency + #571 — Multiplication with subset constant differs between paths + +User NaN silently swallowed: + #620 — NaN in user data filled with neutral elements + +Absent-slot NaN propagation: + #620 — Multiplication doesn't propagate absence in legacy +""" + +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from linopy import Model +from linopy.variables import Variable + +# ============================================================ +# Fixtures +# ============================================================ + + +@pytest.fixture +def m() -> Model: + return Model() + + +@pytest.fixture +def time() -> pd.RangeIndex: + return pd.RangeIndex(5, name="time") + + +@pytest.fixture +def x(m: Model, time: pd.RangeIndex) -> Variable: + return m.add_variables(lower=0, coords=[time], name="x") + + +@pytest.fixture +def y(m: Model, time: pd.RangeIndex) -> Variable: + return m.add_variables(lower=0, coords=[time], name="y") + + +# ============================================================ +# 1. Positional alignment: same shape, different labels (#586, #550) +# ============================================================ + + +class TestPositionalAlignment: + """ + Legacy uses override (positional) join when operands have matching sizes. + This silently pairs values by array position, ignoring coordinate labels. 
+ + Issues: #586, #550, #257 + """ + + @pytest.mark.legacy_only + def test_add_same_size_different_labels_silent(self, m: Model) -> None: + """ + #550: Two variables with same shape but different labels get + silently paired by position. Labels from the left operand win. + """ + time_a = pd.Index([0, 1, 2, 3, 4], name="time") + time_b = pd.Index([5, 6, 7, 8, 9], name="time") + a = m.add_variables(lower=0, coords=[time_a], name="a") + b = m.add_variables(lower=0, coords=[time_b], name="b") + + result = a + b + # Legacy: silently pairs a[0] with b[5], a[1] with b[6], etc. + # The result has a's labels (0-4), but b's variable IDs are from 5-9 + assert list(result.coords["time"].values) == [0, 1, 2, 3, 4] + # b's variables are present despite the labels saying time=0..4 + b_var_ids = b.labels.values + result_var_ids = result.vars.values[:, 1] # second term is b + np.testing.assert_array_equal(result_var_ids, b_var_ids) + + @pytest.mark.v1_only + def test_add_same_size_different_labels_raises(self, m: Model) -> None: + """v1: Mismatched labels raise ValueError with helpful message.""" + time_a = pd.Index([0, 1, 2, 3, 4], name="time") + time_b = pd.Index([5, 6, 7, 8, 9], name="time") + a = m.add_variables(lower=0, coords=[time_a], name="a") + b = m.add_variables(lower=0, coords=[time_b], name="b") + + with pytest.raises(ValueError, match="Coordinate mismatch"): + a + b + + @pytest.mark.legacy_only + def test_mul_reordered_labels_silent(self, m: Model) -> None: + """ + #550: Multiplying by a constant with reordered labels of the same + size silently uses positional alignment, producing wrong results. 
+ """ + idx = pd.Index(["costs", "penalty"], name="effect") + v = m.add_variables(lower=0, coords=[idx], name="v") + # Reversed order — same labels, different positions + factors = xr.DataArray( + [2.0, 1.0], + dims=["effect"], + coords={"effect": pd.Index(["penalty", "costs"], name="effect")}, + ) + + result = v * factors + # Legacy: positional match → v["costs"] * 2.0, v["penalty"] * 1.0 + # But the user meant: v["costs"] * 1.0, v["penalty"] * 2.0 + assert result.coeffs.sel(effect="costs").item() == 2.0 # WRONG + assert result.coeffs.sel(effect="penalty").item() == 1.0 # WRONG + + @pytest.mark.v1_only + def test_mul_reordered_labels_raises(self, m: Model) -> None: + """v1: Reordered labels on same dim raise ValueError.""" + idx = pd.Index(["costs", "penalty"], name="effect") + v = m.add_variables(lower=0, coords=[idx], name="v") + factors = xr.DataArray( + [2.0, 1.0], + dims=["effect"], + coords={"effect": pd.Index(["penalty", "costs"], name="effect")}, + ) + + with pytest.raises(ValueError, match="exact"): + v * factors + + @pytest.mark.legacy_only + def test_add_reordered_labels_positional(self, m: Model) -> None: + """ + Same labels in different order: legacy silently uses positional + alignment on addition too, producing wrong constant values. 
+ """ + idx_a = pd.Index(["A1", "A5", "A11", "A100"], name="item") + x = m.add_variables(lower=0, coords=[idx_a], name="x") + + # Same labels, different order, same size → override join + rhs = xr.DataArray( + [100.0, 1.0, 5.0, 11.0], + dims=["item"], + coords={"item": pd.Index(["A100", "A1", "A5", "A11"], name="item")}, + ) + result = x + rhs + # Legacy: positional match → A1 gets 100.0, A100 gets 11.0 + assert result.const.sel(item="A1").item() == 100.0 # WRONG + assert result.const.sel(item="A100").item() == 11.0 # WRONG + + @pytest.mark.v1_only + def test_add_reordered_labels_raises(self, m: Model) -> None: + """v1: Reordered labels raise ValueError.""" + idx_a = pd.Index(["A1", "A5", "A11", "A100"], name="item") + x = m.add_variables(lower=0, coords=[idx_a], name="x") + + rhs = xr.DataArray( + [100.0, 1.0, 5.0, 11.0], + dims=["item"], + coords={"item": pd.Index(["A100", "A1", "A5", "A11"], name="item")}, + ) + with pytest.raises(ValueError, match="exact"): + x + rhs + + +# ============================================================ +# 2. Subset constant breaks associativity (#572) +# ============================================================ + + +class TestSubsetConstantAssociativity: + """ + Legacy uses left-join when a constant has different-sized coordinates. + This drops coordinates that might be needed by a later operation, + breaking associativity: (a + c) + b != a + (c + b). + + Issue: #572 (review by @FBumann) + """ + + @pytest.mark.legacy_only + def test_add_order_matters(self, m: Model) -> None: + """ + Adding a subset constant first vs last gives different results + because left-join drops the constant's extra coordinates. 
+ """ + time3 = pd.RangeIndex(3, name="time") + time5 = pd.RangeIndex(5, name="time") + a = m.add_variables(lower=0, coords=[time3], name="a") + b = m.add_variables(lower=0, coords=[time5], name="b") + factor = xr.DataArray( + [10.0, 20.0, 30.0, 40.0, 50.0], + dims=["time"], + coords={"time": time5}, + ) + + # a + factor + b: factor left-joined to a's coords (0,1,2), + # then merged with b (0..4). factor at time=3,4 is lost. + r1 = a + factor + b + # a + b + factor: a+b merged first (outer → 0..4), + # then factor left-joined to (0..4). factor at time=3,4 preserved. + r2 = a + b + factor + + # At time=3,4 the constant should be 40,50 but r1 loses them + assert r1.const.sel(time=3).item() == 0.0 # WRONG: lost + assert r2.const.sel(time=3).item() == 40.0 # correct + + @pytest.mark.v1_only + def test_subset_constant_raises(self, m: Model) -> None: + """v1: Subset constant on a shared dim raises ValueError.""" + time3 = pd.RangeIndex(3, name="time") + a = m.add_variables(lower=0, coords=[time3], name="a") + factor = xr.DataArray( + [10.0, 20.0, 30.0, 40.0, 50.0], + dims=["time"], + coords={"time": pd.RangeIndex(5, name="time")}, + ) + + with pytest.raises(ValueError, match="exact"): + a + factor + + +# ============================================================ +# 3. User NaN silently swallowed (#620) +# ============================================================ + + +class TestUserNaNSwallowed: + """ + Legacy silently fills NaN in user-supplied constants with neutral + elements: 0 for addition, 0 for multiplication (zeroes out variable), + 1 for division (leaves variable unchanged). The fill values are + inconsistent and hide data bugs. 
+ + Issue: #620 + """ + + @pytest.fixture + def nan_data(self, time: pd.RangeIndex) -> xr.DataArray: + vals = np.array([1.0, np.nan, 3.0, 4.0, 5.0]) + return xr.DataArray(vals, dims=["time"], coords={"time": time}) + + @pytest.mark.legacy_only + def test_add_nan_silently_filled_with_zero( + self, x: Variable, nan_data: xr.DataArray + ) -> None: + """NaN in addend becomes 0 — user's missing data silently ignored.""" + result = x + nan_data + assert not np.isnan(result.const.values).any() + assert result.const.sel(time=1).item() == 0.0 # was NaN → 0 + + @pytest.mark.legacy_only + def test_mul_nan_silently_zeroes_variable( + self, x: Variable, nan_data: xr.DataArray + ) -> None: + """NaN in multiplier becomes 0 — variable silently zeroed out.""" + result = x * nan_data + assert not np.isnan(result.coeffs.squeeze().values).any() + assert result.coeffs.squeeze().sel(time=1).item() == 0.0 + + @pytest.mark.legacy_only + def test_div_nan_silently_becomes_one( + self, x: Variable, nan_data: xr.DataArray + ) -> None: + """ + NaN in divisor becomes 1 — variable silently left unchanged. + Note: inconsistent with mul which fills with 0. 
+ """ + # Copy nan_data and override the time=0 divisor; time=1 keeps its NaN + divisor = nan_data.copy() + divisor[0] = 2.0 # 2.0 instead of the original 1.0 + result = x / divisor + assert not np.isnan(result.coeffs.squeeze().values).any() + # time=1 had NaN → filled with 1 → coefficient unchanged + assert result.coeffs.squeeze().sel(time=1).item() == 1.0 + + @pytest.mark.v1_only + def test_add_nan_raises(self, x: Variable, nan_data: xr.DataArray) -> None: + """v1: NaN in user data raises ValueError.""" + with pytest.raises(ValueError, match="NaN"): + x + nan_data + + @pytest.mark.v1_only + def test_mul_nan_raises(self, x: Variable, nan_data: xr.DataArray) -> None: + with pytest.raises(ValueError, match="NaN"): + x * nan_data + + @pytest.mark.v1_only + def test_div_nan_raises(self, x: Variable, nan_data: xr.DataArray) -> None: + with pytest.raises(ValueError, match="NaN"): + x / nan_data + + @pytest.mark.legacy_only + def test_nan_fill_inconsistency(self, x: Variable, nan_data: xr.DataArray) -> None: + """ + Legacy fills NaN with DIFFERENT values per operation: + add→0, mul→0, div→1. This is internally inconsistent. + """ + add_result = x + nan_data + mul_result = x * nan_data + divisor = nan_data.copy() + divisor[0] = 2.0 + div_result = x / divisor + + nan_pos = 1 # time=1 has NaN in input + add_fill = add_result.const.sel(time=nan_pos).item() + mul_fill = mul_result.coeffs.squeeze().sel(time=nan_pos).item() + div_fill = div_result.coeffs.squeeze().sel(time=nan_pos).item() + + assert add_fill == 0.0 # additive "identity" + assert mul_fill == 0.0 # kills the variable + assert div_fill == 1.0 # leaves variable unchanged + # mul fills with 0 (destructive) but div fills with 1 (preserving) + # — no consistent principle + + +# ============================================================ +# 4.
Variable vs Expression inconsistency (#569, #571) +# ============================================================ + + +class TestVariableExpressionInconsistency: + """ + Variable and Expression code paths produce different results for the + same mathematical operation. x * c and (1*x) * c should be identical. + + Issues: #569, #571 + """ + + @pytest.mark.legacy_only + def test_mul_subset_var_vs_expr_same_result(self, m: Model) -> None: + """ + Legacy: after the fix in #572, both paths produce the same result + (fill with 0). Before #572, the expression path crashed. + """ + coords = pd.RangeIndex(5, name="i") + x = m.add_variables(lower=0, coords=[coords], name="x") + subset = xr.DataArray([10.0, 30.0], dims=["i"], coords={"i": [1, 3]}) + + var_result = x * subset + expr_result = (1 * x) * subset + + # Both should produce identical coefficients + np.testing.assert_array_equal( + var_result.coeffs.squeeze().values, + expr_result.coeffs.squeeze().values, + ) + + @pytest.mark.v1_only + def test_mul_subset_both_raise(self, m: Model) -> None: + """v1: Both paths raise the same error.""" + coords = pd.RangeIndex(5, name="i") + x = m.add_variables(lower=0, coords=[coords], name="x") + subset = xr.DataArray([10.0, 30.0], dims=["i"], coords={"i": [1, 3]}) + + with pytest.raises(ValueError, match="exact"): + x * subset + with pytest.raises(ValueError, match="exact"): + (1 * x) * subset + + +# ============================================================ +# 5. Absent slot NaN not propagated in legacy (#620) +# ============================================================ + + +class TestAbsentSlotPropagation: + """ + Legacy does not mark absent variable slots as NaN in to_linexpr(), + so multiplication/division cannot distinguish 'absent' from 'zero'. + + Issue: #620 + """ + + @pytest.mark.legacy_only + def test_absent_times_scalar_becomes_zero(self, x: Variable) -> None: + """ + Legacy: absent slot * 3 becomes coeffs=3, const=0 (a valid + zero-contribution term). 
The absence is lost. + """ + xs = x.shift(time=1) # time=0 is absent + result = xs * 3 + # Legacy keeps labels=-1 at the absent slot but sets coeffs=3: the label + # is unused, yet the coefficient is not NaN, so isnull() is False + assert not result.isnull().values[0] # NOT absent — this is wrong + + @pytest.mark.v1_only + def test_absent_times_scalar_stays_absent(self, x: Variable) -> None: + """v1: absent slot * 3 stays absent (NaN propagates).""" + xs = x.shift(time=1) + result = xs * 3 + assert result.isnull().values[0] # correctly absent + assert not result.isnull().values[1] # valid slot unaffected + + @pytest.mark.legacy_only + def test_absent_indistinguishable_from_zero(self, x: Variable) -> None: + """ + Legacy cannot tell apart an absent variable from a zero variable. + Both produce isnull()=False after multiplication. + """ + xs = x.shift(time=1) # time=0 is absent + result_absent = xs * 3 + result_zero = x * 0 # genuinely zero + + # Both look non-null under legacy — information lost + assert not result_absent.isnull().values[0] + assert not result_zero.isnull().values[0] + + @pytest.mark.v1_only + def test_absent_distinguishable_from_zero(self, x: Variable) -> None: + """v1: absent and zero are distinct.""" + xs = x.shift(time=1) + result_absent = xs * 3 + result_zero = x * 0 + + assert result_absent.isnull().values[0] # absent + assert not result_zero.isnull().values[0] # zero but present + + @pytest.mark.legacy_only + def test_fillna_noop_on_absent_variable(self, x: Variable) -> None: + """ + Legacy: fillna(42) on a shifted variable does nothing because + to_linexpr() doesn't produce NaN to fill. 
+ """ + xs = x.shift(time=1) + result = xs.fillna(42) + # The absent slot at time=0 has const=0 (not NaN), so fillna + # has nothing to replace + assert result.const.values[0] == 0.0 # should be 42 + + @pytest.mark.v1_only + def test_fillna_works_on_absent_variable(self, x: Variable) -> None: + """v1: fillna(42) correctly fills the absent slot.""" + xs = x.shift(time=1) + result = xs.fillna(42) + assert result.const.values[0] == 42.0 + assert result.const.values[1] == 0.0 # valid slot unchanged diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index d3b8d426..90163e4b 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -14,6 +14,7 @@ import polars as pl import pytest import xarray as xr +from xarray.core.types import JoinOptions from xarray.testing import assert_equal from linopy import LinearExpression, Model, QuadraticExpression, Variable, merge @@ -387,6 +388,7 @@ def test_linear_expression_substraction( assert res.data.notnull().all().to_array().all() +@pytest.mark.legacy_only def test_linear_expression_sum( x: Variable, y: Variable, z: Variable, v: Variable ) -> None: @@ -403,12 +405,42 @@ def test_linear_expression_sum( assert_linequal(expr.sum(["dim_0", TERM_DIM]), expr.sum("dim_0")) - # test special case otherride coords + # test special case otherride coords (legacy outer join allows this) expr = v.loc[:9] + v.loc[10:] assert expr.nterm == 2 assert len(expr.coords["dim_2"]) == 10 +@pytest.mark.v1_only +def test_linear_expression_sum_v1( + x: Variable, y: Variable, z: Variable, v: Variable +) -> None: + expr = 10 * x + y + z + res = expr.sum("dim_0") + + assert res.size == expr.size + assert res.nterm == expr.nterm * len(expr.data.dim_0) + + res = expr.sum() + assert res.size == expr.size + assert res.nterm == expr.size + assert res.data.notnull().all().to_array().all() + + assert_linequal(expr.sum(["dim_0", TERM_DIM]), expr.sum("dim_0")) + + # v1: mismatched coords raise ValueError + with 
pytest.raises(ValueError, match="Coordinate mismatch"): + v.loc[:9] + v.loc[10:] + + # explicit assign_coords as workaround + a = v.loc[:9] + b = v.loc[10:].assign_coords(dim_2=a.coords["dim_2"]) + expr = a + b + assert expr.nterm == 2 + assert len(expr.coords["dim_2"]) == 10 + + +@pytest.mark.legacy_only def test_linear_expression_sum_with_const( x: Variable, y: Variable, z: Variable, v: Variable ) -> None: @@ -427,12 +459,43 @@ def test_linear_expression_sum_with_const( assert_linequal(expr.sum(["dim_0", TERM_DIM]), expr.sum("dim_0")) - # test special case otherride coords + # test special case otherride coords (legacy outer join allows this) expr = v.loc[:9] + v.loc[10:] assert expr.nterm == 2 assert len(expr.coords["dim_2"]) == 10 +@pytest.mark.v1_only +def test_linear_expression_sum_with_const_v1( + x: Variable, y: Variable, z: Variable, v: Variable +) -> None: + expr = 10 * x + y + z + 10 + res = expr.sum("dim_0") + + assert res.size == expr.size + assert res.nterm == expr.nterm * len(expr.data.dim_0) + assert (res.const == 20).all() + + res = expr.sum() + assert res.size == expr.size + assert res.nterm == expr.size + assert res.data.notnull().all().to_array().all() + assert (res.const == 60).item() + + assert_linequal(expr.sum(["dim_0", TERM_DIM]), expr.sum("dim_0")) + + # v1: mismatched coords raise ValueError + with pytest.raises(ValueError, match="Coordinate mismatch"): + v.loc[:9] + v.loc[10:] + + # explicit assign_coords as workaround + a = v.loc[:9] + b = v.loc[10:].assign_coords(dim_2=a.coords["dim_2"]) + expr = a + b + assert expr.nterm == 2 + assert len(expr.coords["dim_2"]) == 10 + + def test_linear_expression_sum_drop_zeros(z: Variable) -> None: coeff = xr.zeros_like(z.labels) coeff[1, 0] = 3 @@ -538,6 +601,8 @@ def test_linear_expression_multiplication_invalid( class TestCoordinateAlignment: + """Coordinate alignment tests for both legacy (outer join) and v1 (exact join).""" + @pytest.fixture(params=["da", "series"]) def subset(self, request: 
Any) -> xr.DataArray | pd.Series: if request.param == "da": @@ -574,6 +639,7 @@ def nan_constant(self, request: Any) -> xr.DataArray | pd.Series: return pd.Series(vals, index=pd.Index(range(20), name="dim_2")) class TestSubset: + @pytest.mark.legacy_only @pytest.mark.parametrize("operand", ["var", "expr"]) def test_mul_subset_fills_zeros( self, @@ -588,6 +654,16 @@ def test_mul_subset_fills_zeros( assert not np.isnan(result.coeffs.values).any() np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) + @pytest.mark.v1_only + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_mul_subset_raises( + self, v: Variable, subset: xr.DataArray, operand: str + ) -> None: + target = v if operand == "var" else 1 * v + with pytest.raises(ValueError, match="exact"): + target * subset + + @pytest.mark.legacy_only @pytest.mark.parametrize("operand", ["var", "expr"]) def test_add_subset_fills_zeros( self, @@ -606,6 +682,16 @@ def test_add_subset_fills_zeros( assert not np.isnan(result.const.values).any() np.testing.assert_array_equal(result.const.values, expected) + @pytest.mark.v1_only + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_add_subset_raises( + self, v: Variable, subset: xr.DataArray, operand: str + ) -> None: + target = v if operand == "var" else v + 5 + with pytest.raises(ValueError, match="exact"): + target + subset + + @pytest.mark.legacy_only @pytest.mark.parametrize("operand", ["var", "expr"]) def test_sub_subset_fills_negated( self, @@ -624,6 +710,16 @@ def test_sub_subset_fills_negated( assert not np.isnan(result.const.values).any() np.testing.assert_array_equal(result.const.values, expected) + @pytest.mark.v1_only + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_sub_subset_raises( + self, v: Variable, subset: xr.DataArray, operand: str + ) -> None: + target = v if operand == "var" else v + 5 + with pytest.raises(ValueError, match="exact"): + target - subset + + @pytest.mark.legacy_only 
@pytest.mark.parametrize("operand", ["var", "expr"]) def test_div_subset_inverts_nonzero( self, v: Variable, subset: xr.DataArray, operand: str @@ -635,19 +731,52 @@ def test_div_subset_inverts_nonzero( assert result.coeffs.squeeze().sel(dim_2=1).item() == pytest.approx(0.1) assert result.coeffs.squeeze().sel(dim_2=0).item() == pytest.approx(1.0) + @pytest.mark.v1_only + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_div_subset_raises( + self, v: Variable, subset: xr.DataArray, operand: str + ) -> None: + target = v if operand == "var" else 1 * v + with pytest.raises(ValueError, match="exact"): + target / subset + + @pytest.mark.legacy_only def test_subset_add_var_coefficients( self, v: Variable, subset: xr.DataArray ) -> None: result = subset + v np.testing.assert_array_equal(result.coeffs.squeeze().values, np.ones(20)) + @pytest.mark.v1_only + def test_subset_add_var_raises(self, v: Variable, subset: xr.DataArray) -> None: + with pytest.raises(ValueError, match="exact"): + subset + v + + @pytest.mark.legacy_only def test_subset_sub_var_coefficients( self, v: Variable, subset: xr.DataArray ) -> None: result = subset - v np.testing.assert_array_equal(result.coeffs.squeeze().values, -np.ones(20)) + @pytest.mark.v1_only + def test_subset_sub_var_raises(self, v: Variable, subset: xr.DataArray) -> None: + with pytest.raises(ValueError, match="exact"): + subset - v + + @pytest.mark.v1_only + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_mul_subset_join_left( + self, v: Variable, subset: xr.DataArray, operand: str + ) -> None: + """Explicit join='left' fills zeros for missing coords.""" + target = v if operand == "var" else 1 * v + result = target.mul(subset, join="left") + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.coeffs.values).any() + class TestSuperset: + @pytest.mark.legacy_only def test_add_superset_pins_to_lhs_coords( self, v: Variable, superset: xr.DataArray ) -> None: @@ -655,15 +784,51 @@ 
def test_add_superset_pins_to_lhs_coords( assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.const.values).any() + @pytest.mark.v1_only + def test_add_superset_raises(self, v: Variable, superset: xr.DataArray) -> None: + with pytest.raises(ValueError, match="exact"): + v + superset + + @pytest.mark.legacy_only def test_add_var_commutative(self, v: Variable, superset: xr.DataArray) -> None: assert_linequal(superset + v, v + superset) + @pytest.mark.v1_only + def test_add_var_commutative_raises( + self, v: Variable, superset: xr.DataArray + ) -> None: + with pytest.raises(ValueError, match="exact"): + superset + v + with pytest.raises(ValueError, match="exact"): + v + superset + + @pytest.mark.legacy_only def test_sub_var_commutative(self, v: Variable, superset: xr.DataArray) -> None: assert_linequal(superset - v, -v + superset) + @pytest.mark.v1_only + def test_sub_var_commutative_raises( + self, v: Variable, superset: xr.DataArray + ) -> None: + with pytest.raises(ValueError, match="exact"): + superset - v + with pytest.raises(ValueError, match="exact"): + v - superset + + @pytest.mark.legacy_only def test_mul_var_commutative(self, v: Variable, superset: xr.DataArray) -> None: assert_linequal(superset * v, v * superset) + @pytest.mark.v1_only + def test_mul_var_commutative_raises( + self, v: Variable, superset: xr.DataArray + ) -> None: + with pytest.raises(ValueError, match="exact"): + superset * v + with pytest.raises(ValueError, match="exact"): + v * superset + + @pytest.mark.legacy_only def test_mul_superset_pins_to_lhs_coords( self, v: Variable, superset: xr.DataArray ) -> None: @@ -671,6 +836,12 @@ def test_mul_superset_pins_to_lhs_coords( assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.coeffs.values).any() + @pytest.mark.v1_only + def test_mul_superset_raises(self, v: Variable, superset: xr.DataArray) -> None: + with pytest.raises(ValueError, match="exact"): + v * superset + + @pytest.mark.legacy_only def 
test_div_superset_pins_to_lhs_coords(self, v: Variable) -> None: superset_nonzero = xr.DataArray( np.arange(1, 26, dtype=float), @@ -681,7 +852,18 @@ def test_div_superset_pins_to_lhs_coords(self, v: Variable) -> None: assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.coeffs.values).any() + @pytest.mark.v1_only + def test_div_superset_raises(self, v: Variable) -> None: + superset_nonzero = xr.DataArray( + np.arange(1, 26, dtype=float), + dims=["dim_2"], + coords={"dim_2": range(25)}, + ) + with pytest.raises(ValueError, match="exact"): + v / superset_nonzero + class TestDisjoint: + @pytest.mark.legacy_only def test_add_disjoint_fills_zeros(self, v: Variable) -> None: disjoint = xr.DataArray( [100.0, 200.0], dims=["dim_2"], coords={"dim_2": [50, 60]} @@ -691,6 +873,15 @@ def test_add_disjoint_fills_zeros(self, v: Variable) -> None: assert not np.isnan(result.const.values).any() np.testing.assert_array_equal(result.const.values, np.zeros(20)) + @pytest.mark.v1_only + def test_add_disjoint_raises(self, v: Variable) -> None: + disjoint = xr.DataArray( + [100.0, 200.0], dims=["dim_2"], coords={"dim_2": [50, 60]} + ) + with pytest.raises(ValueError, match="exact"): + v + disjoint + + @pytest.mark.legacy_only def test_mul_disjoint_fills_zeros(self, v: Variable) -> None: disjoint = xr.DataArray( [10.0, 20.0], dims=["dim_2"], coords={"dim_2": [50, 60]} @@ -700,6 +891,15 @@ def test_mul_disjoint_fills_zeros(self, v: Variable) -> None: assert not np.isnan(result.coeffs.values).any() np.testing.assert_array_equal(result.coeffs.squeeze().values, np.zeros(20)) + @pytest.mark.v1_only + def test_mul_disjoint_raises(self, v: Variable) -> None: + disjoint = xr.DataArray( + [10.0, 20.0], dims=["dim_2"], coords={"dim_2": [50, 60]} + ) + with pytest.raises(ValueError, match="exact"): + v * disjoint + + @pytest.mark.legacy_only def test_div_disjoint_preserves_coeffs(self, v: Variable) -> None: disjoint = xr.DataArray( [10.0, 20.0], dims=["dim_2"], 
coords={"dim_2": [50, 60]} @@ -709,7 +909,16 @@ def test_div_disjoint_preserves_coeffs(self, v: Variable) -> None: assert not np.isnan(result.coeffs.values).any() np.testing.assert_array_equal(result.coeffs.squeeze().values, np.ones(20)) + @pytest.mark.v1_only + def test_div_disjoint_raises(self, v: Variable) -> None: + disjoint = xr.DataArray( + [10.0, 20.0], dims=["dim_2"], coords={"dim_2": [50, 60]} + ) + with pytest.raises(ValueError, match="exact"): + v / disjoint + class TestCommutativity: + @pytest.mark.legacy_only @pytest.mark.parametrize( "make_lhs,make_rhs", [ @@ -729,18 +938,37 @@ def test_commutativity( ) -> None: assert_linequal(make_lhs(v, subset), make_rhs(v, subset)) + @pytest.mark.v1_only + @pytest.mark.parametrize( + "op", + [ + lambda v, s: s * v, + lambda v, s: s + v, + lambda v, s: s + (v + 5), + lambda v, s: s - v, + ], + ids=["subset*var", "subset+var", "subset+expr", "subset-var"], + ) + def test_commutativity_raises( + self, v: Variable, subset: xr.DataArray, op: Any + ) -> None: + with pytest.raises(ValueError, match="exact"): + op(v, subset) + + @pytest.mark.legacy_only def test_sub_var_anticommutative( self, v: Variable, subset: xr.DataArray ) -> None: assert_linequal(subset - v, -v + subset) + @pytest.mark.legacy_only def test_sub_expr_anticommutative( self, v: Variable, subset: xr.DataArray ) -> None: expr = v + 5 assert_linequal(subset - expr, -(expr - subset)) - def test_add_commutativity_full_coords(self, v: Variable) -> None: + def test_add_commutativity_matching_coords(self, v: Variable) -> None: full = xr.DataArray( np.arange(20, dtype=float), dims=["dim_2"], @@ -748,8 +976,18 @@ def test_add_commutativity_full_coords(self, v: Variable) -> None: ) assert_linequal(v + full, full + v) + @pytest.mark.v1_only + def test_subset_raises_both_sides( + self, v: Variable, subset: xr.DataArray + ) -> None: + with pytest.raises(ValueError, match="exact"): + v * subset + with pytest.raises(ValueError, match="exact"): + subset * v + class 
TestQuadratic: - def test_quadexpr_add_subset( + @pytest.mark.legacy_only + def test_quadexpr_add_subset_fills( self, v: Variable, subset: xr.DataArray, @@ -762,6 +1000,15 @@ def test_quadexpr_add_subset( assert not np.isnan(result.const.values).any() np.testing.assert_array_equal(result.const.values, expected_fill) + @pytest.mark.v1_only + def test_quadexpr_add_subset_raises( + self, v: Variable, subset: xr.DataArray + ) -> None: + qexpr = v * v + with pytest.raises(ValueError, match="exact"): + qexpr + subset + + @pytest.mark.legacy_only def test_quadexpr_sub_subset( self, v: Variable, @@ -775,7 +1022,16 @@ def test_quadexpr_sub_subset( assert not np.isnan(result.const.values).any() np.testing.assert_array_equal(result.const.values, -expected_fill) - def test_quadexpr_mul_subset( + @pytest.mark.v1_only + def test_quadexpr_sub_subset_raises( + self, v: Variable, subset: xr.DataArray + ) -> None: + qexpr = v * v + with pytest.raises(ValueError, match="exact"): + qexpr - subset + + @pytest.mark.legacy_only + def test_quadexpr_mul_subset_fills( self, v: Variable, subset: xr.DataArray, @@ -788,6 +1044,15 @@ def test_quadexpr_mul_subset( assert not np.isnan(result.coeffs.values).any() np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) + @pytest.mark.v1_only + def test_quadexpr_mul_subset_raises( + self, v: Variable, subset: xr.DataArray + ) -> None: + qexpr = v * v + with pytest.raises(ValueError, match="exact"): + qexpr * subset + + @pytest.mark.legacy_only def test_subset_mul_quadexpr( self, v: Variable, @@ -801,22 +1066,38 @@ def test_subset_mul_quadexpr( assert not np.isnan(result.coeffs.values).any() np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) + @pytest.mark.v1_only + def test_subset_mul_quadexpr_raises( + self, v: Variable, subset: xr.DataArray + ) -> None: + qexpr = v * v + with pytest.raises(ValueError, match="exact"): + subset * qexpr + + @pytest.mark.legacy_only def test_subset_add_quadexpr(self, v: 
Variable, subset: xr.DataArray) -> None: qexpr = v * v assert_quadequal(subset + qexpr, qexpr + subset) + @pytest.mark.v1_only + def test_subset_add_quadexpr_raises( + self, v: Variable, subset: xr.DataArray + ) -> None: + qexpr = v * v + with pytest.raises(ValueError, match="exact"): + subset + qexpr + class TestMissingValues: """ Same shape as variable but with NaN entries in the constant. - NaN values are filled with operation-specific neutral elements: - - Addition/subtraction: NaN -> 0 (additive identity) - - Multiplication: NaN -> 0 (zeroes out the variable) - - Division: NaN -> 1 (multiplicative identity, no scaling) + Legacy: NaN values are filled with operation-specific neutral elements. + V1: NaN values in the constant raise ValueError (no implicit fillna). """ NAN_POSITIONS = [0, 5, 19] + @pytest.mark.legacy_only @pytest.mark.parametrize("operand", ["var", "expr"]) def test_add_nan_filled( self, @@ -833,6 +1114,21 @@ def test_add_nan_filled( for i in self.NAN_POSITIONS: assert result.const.values[i] == base_const + @pytest.mark.v1_only + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_add_nan_raises(self, v: Variable, operand: str) -> None: + vals = np.arange(20, dtype=float) + vals[0] = np.nan + vals[5] = np.nan + vals[19] = np.nan + nan_constant = xr.DataArray( + vals, dims=["dim_2"], coords={"dim_2": range(20)} + ) + target = v if operand == "var" else v + 5 + with pytest.raises(ValueError, match="NaN"): + target + nan_constant + + @pytest.mark.legacy_only @pytest.mark.parametrize("operand", ["var", "expr"]) def test_sub_nan_filled( self, @@ -849,6 +1145,20 @@ def test_sub_nan_filled( for i in self.NAN_POSITIONS: assert result.const.values[i] == base_const + @pytest.mark.v1_only + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_sub_nan_raises(self, v: Variable, operand: str) -> None: + vals = np.arange(20, dtype=float) + for i in self.NAN_POSITIONS: + vals[i] = np.nan + nan_constant = xr.DataArray( + vals, dims=["dim_2"], coords={"dim_2": 
range(20)} + ) + target = v if operand == "var" else v + 5 + with pytest.raises(ValueError, match="NaN"): + target - nan_constant + + @pytest.mark.legacy_only @pytest.mark.parametrize("operand", ["var", "expr"]) def test_mul_nan_filled( self, @@ -864,6 +1174,19 @@ def test_mul_nan_filled( for i in self.NAN_POSITIONS: assert result.coeffs.squeeze().values[i] == 0.0 + @pytest.mark.v1_only + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_mul_nan_raises(self, v: Variable, operand: str) -> None: + vals = np.arange(20, dtype=float) + vals[0] = np.nan + nan_constant = xr.DataArray( + vals, dims=["dim_2"], coords={"dim_2": range(20)} + ) + target = v if operand == "var" else 1 * v + with pytest.raises(ValueError, match="NaN"): + target * nan_constant + + @pytest.mark.legacy_only @pytest.mark.parametrize("operand", ["var", "expr"]) def test_div_nan_filled( self, @@ -880,6 +1203,20 @@ def test_div_nan_filled( for i in self.NAN_POSITIONS: assert result.coeffs.squeeze().values[i] == original_coeffs[i] + @pytest.mark.v1_only + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_div_nan_raises(self, v: Variable, operand: str) -> None: + vals = np.arange(20, dtype=float) + 1 + vals[0] = np.nan + vals[5] = np.nan + nan_constant = xr.DataArray( + vals, dims=["dim_2"], coords={"dim_2": range(20)} + ) + target = v if operand == "var" else 1 * v + with pytest.raises(ValueError, match="NaN"): + target / nan_constant + + @pytest.mark.legacy_only def test_add_commutativity( self, v: Variable, @@ -894,6 +1231,19 @@ def test_add_commutativity( result_a.coeffs.values, result_b.coeffs.values ) + @pytest.mark.v1_only + def test_add_commutativity_nan_raises(self, v: Variable) -> None: + vals = np.arange(20, dtype=float) + vals[0] = np.nan + nan_constant = xr.DataArray( + vals, dims=["dim_2"], coords={"dim_2": range(20)} + ) + with pytest.raises(ValueError, match="NaN"): + v + nan_constant + with pytest.raises(ValueError, match="NaN"): + nan_constant + v + + 
@pytest.mark.legacy_only def test_mul_commutativity( self, v: Variable, @@ -907,6 +1257,19 @@ def test_mul_commutativity( result_a.coeffs.values, result_b.coeffs.values ) + @pytest.mark.v1_only + def test_mul_commutativity_nan_raises(self, v: Variable) -> None: + vals = np.arange(20, dtype=float) + vals[0] = np.nan + nan_constant = xr.DataArray( + vals, dims=["dim_2"], coords={"dim_2": range(20)} + ) + with pytest.raises(ValueError, match="NaN"): + v * nan_constant + with pytest.raises(ValueError, match="NaN"): + nan_constant * v + + @pytest.mark.legacy_only def test_quadexpr_add_nan( self, v: Variable, @@ -918,21 +1281,54 @@ def test_quadexpr_add_nan( assert result.sizes["dim_2"] == 20 assert not np.isnan(result.const.values).any() + @pytest.mark.v1_only + def test_quadexpr_add_nan_raises(self, v: Variable) -> None: + vals = np.arange(20, dtype=float) + vals[0] = np.nan + nan_constant = xr.DataArray( + vals, dims=["dim_2"], coords={"dim_2": range(20)} + ) + qexpr = v * v + with pytest.raises(ValueError, match="NaN"): + qexpr + nan_constant + class TestExpressionWithNaN: - """Test that NaN in expression's own const/coeffs doesn't propagate.""" + """ + Test NaN in expression's own const/coeffs. - def test_shifted_expr_add_scalar(self, v: Variable) -> None: + Legacy: NaN is filled with neutral elements. + V1: NaN propagates. 
+ """ + + @pytest.mark.legacy_only + def test_shifted_expr_add_scalar_filled(self, v: Variable) -> None: expr = (1 * v).shift(dim_2=1) result = expr + 5 assert not np.isnan(result.const.values).any() assert result.const.values[0] == 5.0 - def test_shifted_expr_mul_scalar(self, v: Variable) -> None: + @pytest.mark.v1_only + def test_shifted_expr_add_scalar_revives(self, v: Variable) -> None: + """Addition fills const with 0 (additive identity) then adds.""" + expr = (1 * v).shift(dim_2=1) + result = expr + 5 + assert not np.isnan(result.const.values[0]) + assert result.const.values[0] == 5.0 + + @pytest.mark.legacy_only + def test_shifted_expr_mul_scalar_filled(self, v: Variable) -> None: expr = (1 * v).shift(dim_2=1) result = expr * 2 assert not np.isnan(result.coeffs.squeeze().values).any() assert result.coeffs.squeeze().values[0] == 0.0 + @pytest.mark.v1_only + def test_shifted_expr_mul_scalar_propagates(self, v: Variable) -> None: + expr = (1 * v).shift(dim_2=1) + result = expr * 2 + assert np.isnan(result.coeffs.squeeze().values[0]) + + @pytest.mark.legacy_only def test_shifted_expr_add_array(self, v: Variable) -> None: arr = np.arange(v.sizes["dim_2"], dtype=float) expr = (1 * v).shift(dim_2=1) @@ -940,6 +1336,16 @@ def test_shifted_expr_add_array(self, v: Variable) -> None: assert not np.isnan(result.const.values).any() assert result.const.values[0] == 0.0 + @pytest.mark.v1_only + def test_shifted_expr_add_array_revives(self, v: Variable) -> None: + """Addition fills const with 0 (additive identity) then adds.""" + arr = np.arange(v.sizes["dim_2"], dtype=float) + expr = (1 * v).shift(dim_2=1) + result = expr + arr + assert not np.isnan(result.const.values[0]) + assert result.const.values[0] == 0.0 + + @pytest.mark.legacy_only def test_shifted_expr_mul_array(self, v: Variable) -> None: arr = np.arange(v.sizes["dim_2"], dtype=float) + 1 expr = (1 * v).shift(dim_2=1) @@ -947,18 +1353,42 @@ def test_shifted_expr_mul_array(self, v: Variable) -> None: assert not 
np.isnan(result.coeffs.squeeze().values).any() assert result.coeffs.squeeze().values[0] == 0.0 + @pytest.mark.v1_only + def test_shifted_expr_mul_array_propagates(self, v: Variable) -> None: + arr = np.arange(v.sizes["dim_2"], dtype=float) + 1 + expr = (1 * v).shift(dim_2=1) + result = expr * arr + assert np.isnan(result.coeffs.squeeze().values[0]) + + @pytest.mark.legacy_only def test_shifted_expr_div_scalar(self, v: Variable) -> None: expr = (1 * v).shift(dim_2=1) result = expr / 2 assert not np.isnan(result.coeffs.squeeze().values).any() assert result.coeffs.squeeze().values[0] == 0.0 + @pytest.mark.v1_only + def test_shifted_expr_div_scalar_propagates(self, v: Variable) -> None: + expr = (1 * v).shift(dim_2=1) + result = expr / 2 + assert np.isnan(result.coeffs.squeeze().values[0]) + + @pytest.mark.legacy_only def test_shifted_expr_sub_scalar(self, v: Variable) -> None: expr = (1 * v).shift(dim_2=1) result = expr - 3 assert not np.isnan(result.const.values).any() assert result.const.values[0] == -3.0 + @pytest.mark.v1_only + def test_shifted_expr_sub_scalar_revives(self, v: Variable) -> None: + """Subtraction fills const with 0 (additive identity) then subtracts.""" + expr = (1 * v).shift(dim_2=1) + result = expr - 3 + assert not np.isnan(result.const.values[0]) + assert result.const.values[0] == -3.0 + + @pytest.mark.legacy_only def test_shifted_expr_div_array(self, v: Variable) -> None: arr = np.arange(v.sizes["dim_2"], dtype=float) + 1 expr = (1 * v).shift(dim_2=1) @@ -966,7 +1396,15 @@ def test_shifted_expr_div_array(self, v: Variable) -> None: assert not np.isnan(result.coeffs.squeeze().values).any() assert result.coeffs.squeeze().values[0] == 0.0 + @pytest.mark.v1_only + def test_shifted_expr_div_array_propagates(self, v: Variable) -> None: + arr = np.arange(v.sizes["dim_2"], dtype=float) + 1 + expr = (1 * v).shift(dim_2=1) + result = expr / arr + assert np.isnan(result.coeffs.squeeze().values[0]) + def test_variable_to_linexpr_nan_coefficient(self, v: 
Variable) -> None: + """to_linexpr fills NaN with 0 under both conventions (internal conversion).""" nan_coeff = np.ones(v.sizes["dim_2"]) nan_coeff[0] = np.nan result = v.to_linexpr(nan_coeff) @@ -974,7 +1412,8 @@ def test_variable_to_linexpr_nan_coefficient(self, v: Variable) -> None: assert result.coeffs.squeeze().values[0] == 0.0 class TestMultiDim: - def test_multidim_subset_mul(self, m: Model) -> None: + @pytest.mark.legacy_only + def test_multidim_subset_mul_fills(self, m: Model) -> None: coords_a = pd.RangeIndex(4, name="a") coords_b = pd.RangeIndex(5, name="b") w = m.add_variables(coords=[coords_a, coords_b], name="w") @@ -993,7 +1432,21 @@ def test_multidim_subset_mul(self, m: Model) -> None: assert result.coeffs.squeeze().sel(a=0, b=0).item() == pytest.approx(0.0) assert result.coeffs.squeeze().sel(a=1, b=2).item() == pytest.approx(0.0) - def test_multidim_subset_add(self, m: Model) -> None: + @pytest.mark.v1_only + def test_multidim_subset_mul_raises(self, m: Model) -> None: + coords_a = pd.RangeIndex(4, name="a") + coords_b = pd.RangeIndex(5, name="b") + w = m.add_variables(coords=[coords_a, coords_b], name="w") + subset_2d = xr.DataArray( + [[2.0, 3.0], [4.0, 5.0]], + dims=["a", "b"], + coords={"a": [1, 3], "b": [0, 4]}, + ) + with pytest.raises(ValueError, match="exact"): + w * subset_2d + + @pytest.mark.legacy_only + def test_multidim_subset_add_fills(self, m: Model) -> None: coords_a = pd.RangeIndex(4, name="a") coords_b = pd.RangeIndex(5, name="b") w = m.add_variables(coords=[coords_a, coords_b], name="w") @@ -1011,6 +1464,19 @@ def test_multidim_subset_add(self, m: Model) -> None: assert result.const.sel(a=3, b=4).item() == pytest.approx(5.0) assert result.const.sel(a=0, b=0).item() == pytest.approx(0.0) + @pytest.mark.v1_only + def test_multidim_subset_add_raises(self, m: Model) -> None: + coords_a = pd.RangeIndex(4, name="a") + coords_b = pd.RangeIndex(5, name="b") + w = m.add_variables(coords=[coords_a, coords_b], name="w") + subset_2d = 
xr.DataArray( + [[2.0, 3.0], [4.0, 5.0]], + dims=["a", "b"], + coords={"a": [1, 3], "b": [0, 4]}, + ) + with pytest.raises(ValueError, match="exact"): + w + subset_2d + class TestXarrayCompat: def test_da_eq_da_still_works(self) -> None: da1 = xr.DataArray([1, 2, 3]) @@ -1877,13 +2343,25 @@ def c(self, m2: Model) -> Variable: return m2.variables["c"] class TestAddition: + @pytest.mark.legacy_only def test_add_join_none_preserves_default( self, a: Variable, b: Variable ) -> None: + """Legacy: join=None uses outer join for mismatched coords.""" result_default = a.to_linexpr() + b.to_linexpr() result_none = a.to_linexpr().add(b.to_linexpr(), join=None) assert_linequal(result_default, result_none) + @pytest.mark.v1_only + def test_add_join_none_raises_on_mismatch_v1( + self, a: Variable, b: Variable + ) -> None: + """V1: join=None uses exact join, raises on mismatched coords.""" + with pytest.raises(ValueError, match="Coordinate mismatch"): + a.to_linexpr() + b.to_linexpr() + with pytest.raises(ValueError, match="Coordinate mismatch"): + a.to_linexpr().add(b.to_linexpr(), join=None) + def test_add_expr_join_inner(self, a: Variable, b: Variable) -> None: result = a.to_linexpr().add(b.to_linexpr(), join="inner") assert list(result.data.indexes["i"]) == [1, 2] @@ -1920,7 +2398,8 @@ def test_add_constant_join_override(self, a: Variable, c: Variable) -> None: def test_add_same_coords_all_joins(self, a: Variable, c: Variable) -> None: expr_a = 1 * a + 5 const = xr.DataArray([1, 2, 3], dims=["i"], coords={"i": [0, 1, 2]}) - for join in ["override", "outer", "inner"]: + joins: list[JoinOptions] = ["override", "outer", "inner"] + for join in joins: result = expr_a.add(const, join=join) assert list(result.coords["i"].values) == [0, 1, 2] np.testing.assert_array_equal(result.const.values, [6, 7, 8]) @@ -2137,24 +2616,48 @@ def test_div_constant_outer_fill_values(self, a: Variable) -> None: assert result.coeffs.squeeze().sel(i=0).item() == pytest.approx(1.0) class TestQuadratic: + 
@pytest.mark.legacy_only def test_quadratic_add_constant_join_inner( self, a: Variable, b: Variable ) -> None: + """Legacy: a*b with mismatched coords uses outer join.""" quad = a.to_linexpr() * b.to_linexpr() const = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [1, 2, 3]}) result = quad.add(const, join="inner") assert list(result.data.indexes["i"]) == [1, 2, 3] + @pytest.mark.v1_only + def test_quadratic_add_constant_join_inner_v1( + self, a: Variable, c: Variable + ) -> None: + """V1: use a*c (same coords) to create quad, then join inner.""" + quad = a.to_linexpr() * c.to_linexpr() + const = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [1, 2, 3]}) + result = quad.add(const, join="inner") + assert list(result.data.indexes["i"]) == [1, 2] + def test_quadratic_add_expr_join_inner(self, a: Variable) -> None: quad = a.to_linexpr() * a.to_linexpr() const = xr.DataArray([10, 20], dims=["i"], coords={"i": [0, 1]}) result = quad.add(const, join="inner") assert list(result.data.indexes["i"]) == [0, 1] + @pytest.mark.legacy_only def test_quadratic_mul_constant_join_inner( self, a: Variable, b: Variable ) -> None: + """Legacy: a*b with mismatched coords uses outer join.""" quad = a.to_linexpr() * b.to_linexpr() const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) result = quad.mul(const, join="inner") assert list(result.data.indexes["i"]) == [1, 2, 3] + + @pytest.mark.v1_only + def test_quadratic_mul_constant_join_inner_v1( + self, a: Variable, c: Variable + ) -> None: + """V1: use a*c (same coords) to create quad, then join inner.""" + quad = a.to_linexpr() * c.to_linexpr() + const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) + result = quad.mul(const, join="inner") + assert list(result.data.indexes["i"]) == [1, 2] diff --git a/test/test_optimization.py b/test/test_optimization.py index cdac8e61..4696f6c2 100644 --- a/test/test_optimization.py +++ b/test/test_optimization.py @@ -186,8 +186,8 @@ def 
model_with_non_aligned_variables() -> Model: lower = pd.Series(0, range(8)) y = m.add_variables(lower=lower, coords=[lower.index], name="y") - m.add_constraints(x + y, GREATER_EQUAL, 10.5) - m.objective = 1 * x + 0.5 * y + m.add_constraints(x.add(y, join="outer"), GREATER_EQUAL, 10.5) + m.objective = x.add(0.5 * y, join="outer") return m @@ -1108,14 +1108,15 @@ def auto_mask_variable_model() -> Model: @pytest.fixture def auto_mask_constraint_model() -> Model: - """Model with auto_mask=True and NaN in constraint RHS.""" - m = Model(auto_mask=True) + """Model with NaN in constraint RHS, masked explicitly.""" + m = Model() x = m.add_variables(lower=0, coords=[range(10)], name="x") y = m.add_variables(lower=0, coords=[range(10)], name="y") rhs = pd.Series([10.0] * 8 + [np.nan, np.nan], range(10)) - m.add_constraints(x + y, GREATER_EQUAL, rhs) # NaN rhs auto-masked + mask = rhs.notnull() + m.add_constraints(x + y, GREATER_EQUAL, rhs.fillna(0), mask=mask) m.add_constraints(x + y, GREATER_EQUAL, 5) m.add_objective(2 * x + y) diff --git a/test/test_repr.py b/test/test_repr.py index 9a7af893..e2782d41 100644 --- a/test/test_repr.py +++ b/test/test_repr.py @@ -1,15 +1,20 @@ from __future__ import annotations +import warnings + import numpy as np import pandas as pd import pytest import xarray as xr from linopy import Model, options +from linopy.config import LinopyDeprecationWarning from linopy.constraints import Constraint from linopy.expressions import LinearExpression from linopy.variables import Variable +warnings.filterwarnings("ignore", category=LinopyDeprecationWarning) + m = Model() lower = pd.Series(0, range(10))