diff --git a/doc/coordinate-alignment.nblink b/doc/coordinate-alignment.nblink new file mode 100644 index 00000000..ef588b91 --- /dev/null +++ b/doc/coordinate-alignment.nblink @@ -0,0 +1,3 @@ +{ + "path": "../examples/coordinate-alignment.ipynb" +} diff --git a/examples/coordinate-alignment.ipynb b/examples/coordinate-alignment.ipynb index 1547bd9d..e1309e37 100644 --- a/examples/coordinate-alignment.ipynb +++ b/examples/coordinate-alignment.ipynb @@ -4,469 +4,479 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Coordinate Alignment\n", + "# Coordinate Alignment in linopy\n", "\n", - "Since linopy builds on xarray, coordinate alignment matters when combining variables or expressions that live on different coordinates. By default, linopy aligns operands automatically and fills missing entries with sensible defaults. This guide shows how alignment works and how to control it with the ``join`` parameter." + "linopy enforces strict defaults for coordinate alignment so that mismatches never silently produce wrong results.\n", + "\n", + "| Operation | Shared-dim alignment | Extra dims on constant/RHS |\n", + "|-----------|---------------------|---------------------------|\n", + "| `+`, `-` | `\"exact\"` — must match | **Forbidden** |\n", + "| `*`, `/` | `\"inner\"` — intersection | Expands the expression |\n", + "| `<=`, `>=`, `==` | `\"exact\"` — must match | **Forbidden** |\n", + "\n", + "**Why?** Addition and constraint RHS only change constant terms — expanding into new dimensions would duplicate the same variable. Multiplication changes coefficients, so expanding is meaningful. The rules are consistent: `a*x + b <= 0` and `a*x <= -b` always behave identically.\n", + "\n", + "When coordinates don't match, use the named methods (`.add()`, `.sub()`, `.mul()`, `.div()`, `.le()`, `.ge()`, `.eq()`) with an explicit `join=` parameter.\n", + "\n", + "Inspired by [pyoframe](https://github.com/Bravos-Power/pyoframe)." ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-20T12:35:53.150316Z", + "iopub.status.busy": "2026-02-20T12:35:53.150100Z", + "iopub.status.idle": "2026-02-20T12:35:54.105967Z", + "shell.execute_reply": "2026-02-20T12:35:54.105432Z" + }, + "ExecuteTime": { + "end_time": "2026-02-20T12:36:56.193551Z", + "start_time": "2026-02-20T12:36:56.190913Z" + } + }, "source": [ "import numpy as np\n", "import pandas as pd\n", "import xarray as xr\n", "\n", "import linopy" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Default Alignment Behavior\n", - "\n", - "When two operands share a dimension but have different coordinates, linopy keeps the **larger** (superset) coordinate range and fills missing positions with zeros (for addition) or zero coefficients (for multiplication)." + "## What works by default" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-20T12:35:54.110532Z", + "iopub.status.busy": "2026-02-20T12:35:54.109029Z", + "iopub.status.idle": "2026-02-20T12:35:54.164335Z", + "shell.execute_reply": "2026-02-20T12:35:54.163789Z" + }, + "ExecuteTime": { + "end_time": "2026-02-20T12:36:56.215580Z", + "start_time": "2026-02-20T12:36:56.207497Z" + } + }, "source": [ "m = linopy.Model()\n", "\n", "time = pd.RangeIndex(5, name=\"time\")\n", - "x = m.add_variables(lower=0, coords=[time], name=\"x\")\n", + "techs = pd.Index([\"solar\", \"wind\", \"gas\"], name=\"tech\")\n", "\n", - "subset_time = pd.RangeIndex(3, name=\"time\")\n", - "y = m.add_variables(lower=0, coords=[subset_time], name=\"y\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Adding ``x`` (5 time steps) and ``y`` (3 time steps) gives an expression over all 5 time steps. Where ``y`` has no entry (time 3, 4), the coefficient is zero — i.e. ``y`` simply drops out of the sum at those positions." - ] + "x = m.add_variables(lower=0, coords=[time], name=\"x\")\n", + "y = m.add_variables(lower=0, coords=[time], name=\"y\")\n", + "gen = m.add_variables(lower=0, coords=[time, techs], name=\"gen\")" + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-20T12:35:54.166957Z", + "iopub.status.busy": "2026-02-20T12:35:54.166600Z", + "iopub.status.idle": "2026-02-20T12:35:54.185234Z", + "shell.execute_reply": "2026-02-20T12:35:54.184778Z" + }, + "ExecuteTime": { + "end_time": "2026-02-20T12:36:56.230513Z", + "start_time": "2026-02-20T12:36:56.222101Z" + } + }, "source": [ + "# Addition/subtraction — matching coordinates\n", "x + y" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The same applies when multiplying by a constant that covers only a subset of coordinates. Missing positions get a coefficient of zero:" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-20T12:35:54.187479Z", + "iopub.status.busy": "2026-02-20T12:35:54.187284Z", + "iopub.status.idle": "2026-02-20T12:35:54.197488Z", + "shell.execute_reply": "2026-02-20T12:35:54.197090Z" + }, + "ExecuteTime": { + "end_time": "2026-02-20T12:36:56.241644Z", + "start_time": "2026-02-20T12:36:56.235473Z" + } + }, "source": [ - "factor = xr.DataArray([2, 3, 4], dims=[\"time\"], coords={\"time\": [0, 1, 2]})\n", + "# Multiplication — matching coordinates\n", + "factor = xr.DataArray([2, 3, 4, 5, 6], dims=[\"time\"], coords={\"time\": time})\n", "x * factor" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Adding a constant subset also fills missing coordinates with zero:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, + ], "outputs": [], - "source": [ - "x + factor" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Constraints with Subset RHS\n", - "\n", - "For constraints, missing right-hand-side values are filled with ``NaN``, which tells linopy to **skip** the constraint at those positions:" - ] + "execution_count": null }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "rhs = xr.DataArray([10, 20, 30], dims=[\"time\"], coords={\"time\": [0, 1, 2]})\n", - "con = x <= rhs\n", - "con" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-20T12:35:54.199528Z", + "iopub.status.busy": "2026-02-20T12:35:54.199323Z", + "iopub.status.idle": "2026-02-20T12:35:54.210352Z", + "shell.execute_reply": "2026-02-20T12:35:54.209978Z" + }, + "ExecuteTime": { + "end_time": "2026-02-20T12:36:56.253971Z", + "start_time": "2026-02-20T12:36:56.246880Z" + } + }, "source": [ - "The constraint only applies at time 0, 1, 2. At time 3 and 4 the RHS is ``NaN``, so no constraint is created." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "### Same-Shape Operands: Positional Alignment\n\nWhen two operands have the **same shape** on a shared dimension, linopy uses **positional alignment** by default — coordinate labels are ignored and the left operand's labels are kept. This is a performance optimization but can be surprising:" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, + "# Multiplication — partial overlap gives intersection\n", + "partial = xr.DataArray([10, 20, 30], dims=[\"time\"], coords={\"time\": [0, 1, 2]})\n", + "x * partial # result: time 0, 1, 2 only" + ], "outputs": [], - "source": [ - "offset_const = xr.DataArray(\n", - " [10, 20, 30, 40, 50], dims=[\"time\"], coords={\"time\": [5, 6, 7, 8, 9]}\n", - ")\n", - "x + offset_const" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "Even though ``offset_const`` has coordinates ``[5, 6, 7, 8, 9]`` and ``x`` has ``[0, 1, 2, 3, 4]``, the result uses ``x``'s labels. The values are aligned by **position**, not by label. The same applies when adding two variables or expressions of identical shape:" + "execution_count": null }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-20T12:35:54.212115Z", + "iopub.status.busy": "2026-02-20T12:35:54.211953Z", + "iopub.status.idle": "2026-02-20T12:35:54.223732Z", + "shell.execute_reply": "2026-02-20T12:35:54.223319Z" + }, + "ExecuteTime": { + "end_time": "2026-02-20T12:36:56.267382Z", + "start_time": "2026-02-20T12:36:56.259835Z" + } + }, "source": [ - "z = m.add_variables(lower=0, coords=[pd.RangeIndex(5, 10, name=\"time\")], name=\"z\")\n", - "x + z" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "``x`` (time 0–4) and ``z`` (time 5–9) share no coordinate labels, yet the result has 5 entries under ``x``'s coordinates — because they have the same shape, positions are matched directly.\n\nTo force **label-based** alignment, pass an explicit ``join``:" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, + "# Multiplication — different dims broadcast (expands the expression)\n", + "cost = xr.DataArray([1.0, 0.5, 3.0], dims=[\"tech\"], coords={\"tech\": techs})\n", + "x * cost # result: (time, tech)" + ], "outputs": [], - "source": [ - "x.add(z, join=\"outer\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "With ``join=\"outer\"``, the result spans all 10 time steps (union of 0–4 and 5–9), filling missing positions with zeros. This is the correct label-based alignment. The same-shape positional shortcut is equivalent to ``join=\"override\"`` — see below." - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## The ``join`` Parameter\n", - "\n", - "For explicit control over alignment, use the ``.add()``, ``.sub()``, ``.mul()``, and ``.div()`` methods with a ``join`` parameter. The supported values follow xarray conventions:\n", - "\n", - "- ``\"inner\"`` — intersection of coordinates\n", - "- ``\"outer\"`` — union of coordinates (with fill)\n", - "- ``\"left\"`` — keep left operand's coordinates\n", - "- ``\"right\"`` — keep right operand's coordinates\n", - "- ``\"override\"`` — positional alignment, ignore coordinate labels\n", - "- ``\"exact\"`` — coordinates must match exactly (raises on mismatch)" - ] + "execution_count": null }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "m2 = linopy.Model()\n", - "\n", - "i_a = pd.Index([0, 1, 2], name=\"i\")\n", - "i_b = pd.Index([1, 2, 3], name=\"i\")\n", - "\n", - "a = m2.add_variables(coords=[i_a], name=\"a\")\n", - "b = m2.add_variables(coords=[i_b], name=\"b\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-20T12:35:54.225717Z", + "iopub.status.busy": "2026-02-20T12:35:54.225519Z", + "iopub.status.idle": "2026-02-20T12:35:54.247553Z", + "shell.execute_reply": "2026-02-20T12:35:54.247125Z" + }, + "ExecuteTime": { + "end_time": "2026-02-20T12:36:56.305476Z", + "start_time": "2026-02-20T12:36:56.292Z" + } + }, "source": [ - "**Inner join** — only shared coordinates (i=1, 2):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, + "# Constraints — RHS with fewer dims broadcasts naturally\n", + "capacity = xr.DataArray([100, 80, 50], dims=[\"tech\"], coords={\"tech\": techs})\n", + "m.add_constraints(gen <= capacity, name=\"cap\") # capacity broadcasts over time" + ], "outputs": [], - "source": [ - "a.add(b, join=\"inner\")" - ] + "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ - "**Outer join** — union of coordinates (i=0, 1, 2, 3):" + "## What raises an error" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "a.add(b, join=\"outer\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-20T12:35:54.249529Z", + "iopub.status.busy": "2026-02-20T12:35:54.249355Z", + "iopub.status.idle": "2026-02-20T12:35:54.260588Z", + "shell.execute_reply": "2026-02-20T12:35:54.259868Z" + }, + "ExecuteTime": { + "end_time": "2026-02-20T12:36:56.319773Z", + "start_time": "2026-02-20T12:36:56.312636Z" + } + }, "source": [ - "**Left join** — keep left operand's coordinates (i=0, 1, 2):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, + "# Addition with mismatched coordinates\n", + "y_short = m.add_variables(\n", + " lower=0, coords=[pd.RangeIndex(3, name=\"time\")], name=\"y_short\"\n", + ")\n", + "\n", + "try:\n", + " x + y_short # time coords don't match\n", + "except ValueError as e:\n", + " print(\"ValueError:\", e)" + ], "outputs": [], - "source": [ - "a.add(b, join=\"left\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Right join** — keep right operand's coordinates (i=1, 2, 3):" - ] + "execution_count": null }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-20T12:35:54.262548Z", + "iopub.status.busy": "2026-02-20T12:35:54.262376Z", + "iopub.status.idle": "2026-02-20T12:35:54.268753Z", + "shell.execute_reply": "2026-02-20T12:35:54.268391Z" + }, + "ExecuteTime": { + "end_time": "2026-02-20T12:36:56.331386Z", + "start_time": "2026-02-20T12:36:56.326247Z" + } + }, "source": [ - "a.add(b, join=\"right\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "**Override** — positional alignment, ignore coordinate labels. The result uses the left operand's coordinates. Here ``a`` has i=[0, 1, 2] and ``b`` has i=[1, 2, 3], so positions are matched as 0↔1, 1↔2, 2↔3:" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, + "# Addition with extra dimensions on the constant\n", + "profile = xr.DataArray(\n", + " np.ones((3, 5)), dims=[\"tech\", \"time\"], coords={\"tech\": techs, \"time\": time}\n", + ")\n", + "try:\n", + " x + profile # would duplicate x[t] across techs\n", + "except ValueError as e:\n", + " print(\"ValueError:\", e)" + ], "outputs": [], - "source": [ - "a.add(b, join=\"override\")" - ] + "execution_count": null }, { - "cell_type": "markdown", - "metadata": {}, + "cell_type": "code", + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-20T12:35:54.270585Z", + "iopub.status.busy": "2026-02-20T12:35:54.270420Z", + "iopub.status.idle": "2026-02-20T12:35:54.277993Z", + "shell.execute_reply": "2026-02-20T12:35:54.276363Z" + }, + "ExecuteTime": { + "end_time": "2026-02-20T12:36:56.350503Z", + "start_time": "2026-02-20T12:36:56.343806Z" + } + }, "source": [ - "### Multiplication with ``join``\n", + "# Multiplication with zero overlap\n", + "z = m.add_variables(lower=0, coords=[pd.RangeIndex(5, 10, name=\"time\")], name=\"z\")\n", "\n", - "The same ``join`` parameter works on ``.mul()`` and ``.div()``. When multiplying by a constant that covers a subset, ``join=\"inner\"`` restricts the result to shared coordinates only, while ``join=\"left\"`` fills missing values with zero:" - ] + "try:\n", + " z * factor # z has time 5-9, factor has time 0-4 — no intersection\n", + "except ValueError as e:\n", + " print(\"ValueError:\", e)" + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-20T12:35:54.281858Z", + "iopub.status.busy": "2026-02-20T12:35:54.281316Z", + "iopub.status.idle": "2026-02-20T12:35:54.287843Z", + "shell.execute_reply": "2026-02-20T12:35:54.287269Z" + }, + "ExecuteTime": { + "end_time": "2026-02-20T12:36:56.361211Z", + "start_time": "2026-02-20T12:36:56.356813Z" + } + }, "source": [ - "const = xr.DataArray([2, 3, 4], dims=[\"i\"], coords={\"i\": [1, 2, 3]})\n", + "# Constraint RHS with mismatched coordinates\n", + "partial_rhs = xr.DataArray([10, 20, 30], dims=[\"time\"], coords={\"time\": [0, 1, 2]})\n", "\n", - "a.mul(const, join=\"inner\")" - ] + "try:\n", + " x <= partial_rhs\n", + "except ValueError as e:\n", + " print(\"ValueError:\", e)" + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-20T12:35:54.290439Z", + "iopub.status.busy": "2026-02-20T12:35:54.290235Z", + "iopub.status.idle": "2026-02-20T12:35:54.302535Z", + "shell.execute_reply": "2026-02-20T12:35:54.302145Z" + }, + "ExecuteTime": { + "end_time": "2026-02-20T12:36:56.385743Z", + "start_time": "2026-02-20T12:36:56.380702Z" + } + }, "source": [ - "a.mul(const, join=\"left\")" - ] + "# Constraint RHS with extra dimensions\n", + "w = m.add_variables(lower=0, coords=[techs], name=\"w\") # dims: (tech,)\n", + "rhs_2d = xr.DataArray(\n", + " np.ones((5, 3)), dims=[\"time\", \"tech\"], coords={\"time\": time, \"tech\": techs}\n", + ")\n", + "try:\n", + " w <= rhs_2d # would create redundant constraints on w[tech]\n", + "except ValueError as e:\n", + " print(\"ValueError:\", e)" + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Alignment in Constraints\n", + "## Positional alignment\n", "\n", - "The ``.le()``, ``.ge()``, and ``.eq()`` methods create constraints with explicit coordinate alignment. They accept the same ``join`` parameter:" + "A common pattern: two arrays with the same shape but different (or no) coordinate labels. The cleanest fix is to relabel one operand with `.assign_coords()` so that coordinates match explicitly:" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-20T12:35:54.304505Z", + "iopub.status.busy": "2026-02-20T12:35:54.304317Z", + "iopub.status.idle": "2026-02-20T12:35:54.322551Z", + "shell.execute_reply": "2026-02-20T12:35:54.322153Z" + }, + "ExecuteTime": { + "end_time": "2026-02-20T12:37:36.671817Z", + "start_time": "2026-02-20T12:37:36.662325Z" + } + }, "source": [ - "rhs = xr.DataArray([10, 20], dims=[\"i\"], coords={\"i\": [0, 1]})\n", + "m2 = linopy.Model()\n", "\n", - "a.le(rhs, join=\"inner\")" - ] + "a = m2.add_variables(coords=[[\"x\", \"y\", \"z\"]], name=\"a\")\n", + "b = m2.add_variables(coords=[[\"p\", \"q\", \"r\"]], name=\"b\")\n", + "\n", + "# Relabel b's coordinates to match a, then add normally\n", + "a + b.assign_coords(dim_0=a.coords[\"dim_0\"])" + ], + "outputs": [], + "execution_count": null }, { - "cell_type": "markdown", - "metadata": {}, + "cell_type": "code", + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-20T12:35:54.324642Z", + "iopub.status.busy": "2026-02-20T12:35:54.324465Z", + "iopub.status.idle": "2026-02-20T12:35:54.332579Z", + "shell.execute_reply": "2026-02-20T12:35:54.332088Z" + }, + "ExecuteTime": { + "end_time": "2026-02-20T12:36:56.424015Z", + "start_time": "2026-02-20T12:36:56.418311Z" + } + }, "source": [ - "With ``join=\"inner\"``, the constraint only exists at the intersection (i=0, 1). Compare with ``join=\"left\"``:" - ] + "# Same for constraints\n", + "rhs = xr.DataArray([1.0, 2.0, 3.0], dims=[\"dim_0\"], coords={\"dim_0\": [\"p\", \"q\", \"r\"]})\n", + "a <= rhs.assign_coords(dim_0=a.coords[\"dim_0\"])" + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-20T12:35:54.336196Z", + "iopub.status.busy": "2026-02-20T12:35:54.335947Z", + "iopub.status.idle": "2026-02-20T12:35:54.360683Z", + "shell.execute_reply": "2026-02-20T12:35:54.359622Z" + }, + "ExecuteTime": { + "end_time": "2026-02-20T12:36:56.441516Z", + "start_time": "2026-02-20T12:36:56.432774Z" + } + }, "source": [ - "a.le(rhs, join=\"left\")" - ] + "# Shorthand: join=\"override\" does the same (positional match, keeps left labels)\n", + "a.add(b, join=\"override\")" + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ - "With ``join=\"left\"``, the result covers all of ``a``'s coordinates (i=0, 1, 2). At i=2, where the RHS has no value, the RHS becomes ``NaN`` and the constraint is masked out.\n", + "## Other join modes\n", "\n", - "The same methods work on expressions:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "expr = 2 * a + 1\n", - "expr.eq(rhs, join=\"inner\")" + "All named methods (`.add()`, `.sub()`, `.mul()`, `.div()`, `.le()`, `.ge()`, `.eq()`) accept a `join=` parameter:\n", + "\n", + "| `join` | Coordinates kept | Fill |\n", + "|--------|-----------------|------|\n", + "| `\"exact\"` | Must match | `ValueError` if different |\n", + "| `\"inner\"` | Intersection | — |\n", + "| `\"outer\"` | Union | Zero (arithmetic) / NaN (constraints) |\n", + "| `\"left\"` | Left operand's | Zero / NaN for missing right |\n", + "| `\"right\"` | Right operand's | Zero for missing left |\n", + "| `\"override\"` | Left operand's | Positional alignment |" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "## Practical Example\n\nConsider a generation dispatch model where solar availability follows a daily profile and a minimum demand constraint only applies during peak hours." - }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "execution": { + "iopub.execute_input": "2026-02-20T12:35:54.363885Z", + "iopub.status.busy": "2026-02-20T12:35:54.363642Z", + "iopub.status.idle": "2026-02-20T12:35:54.404550Z", + "shell.execute_reply": "2026-02-20T12:35:54.403860Z" + }, + "ExecuteTime": { + "end_time": "2026-02-20T12:36:56.472328Z", + "start_time": "2026-02-20T12:36:56.446352Z" + } + }, "source": [ - "m3 = linopy.Model()\n", + "i_a = pd.Index([0, 1, 2], name=\"i\")\n", + "i_b = pd.Index([1, 2, 3], name=\"i\")\n", "\n", - "hours = pd.RangeIndex(24, name=\"hour\")\n", - "techs = pd.Index([\"solar\", \"wind\", \"gas\"], name=\"tech\")\n", + "a = m2.add_variables(coords=[i_a], name=\"a2\")\n", + "b = m2.add_variables(coords=[i_b], name=\"b2\")\n", "\n", - "gen = m3.add_variables(lower=0, coords=[hours, techs], name=\"gen\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Capacity limits apply to all hours and techs — standard broadcasting handles this:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, + "print(\"inner:\", list(a.add(b, join=\"inner\").coords[\"i\"].values)) # [1, 2]\n", + "print(\"outer:\", list(a.add(b, join=\"outer\").coords[\"i\"].values)) # [0, 1, 2, 3]\n", + "print(\"left: \", list(a.add(b, join=\"left\").coords[\"i\"].values)) # [0, 1, 2]\n", + "print(\"right:\", list(a.add(b, join=\"right\").coords[\"i\"].values)) # [1, 2, 3]" + ], "outputs": [], - "source": [ - "capacity = xr.DataArray([100, 80, 50], dims=[\"tech\"], coords={\"tech\": techs})\n", - "m3.add_constraints(gen <= capacity, name=\"capacity_limit\")" - ] + "execution_count": null }, { "cell_type": "markdown", "metadata": {}, - "source": "For solar, we build a full 24-hour availability profile — zero at night, sine-shaped during daylight (hours 6–18). Since this covers all hours, standard alignment works directly and solar is properly constrained to zero at night:" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ - "solar_avail = np.zeros(24)\n", - "solar_avail[6:19] = 100 * np.sin(np.linspace(0, np.pi, 13))\n", - "solar_availability = xr.DataArray(solar_avail, dims=[\"hour\"], coords={\"hour\": hours})\n", + "## Migrating from previous versions\n", "\n", - "solar_gen = gen.sel(tech=\"solar\")\n", - "m3.add_constraints(solar_gen <= solar_availability, name=\"solar_avail\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "Now suppose a minimum demand of 120 MW must be met, but only during peak hours (8–20). The demand array covers a subset of hours, so we use ``join=\"inner\"`` to restrict the constraint to just those hours:" - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "peak_hours = pd.RangeIndex(8, 21, name=\"hour\")\n", - "peak_demand = xr.DataArray(\n", - " np.full(len(peak_hours), 120.0), dims=[\"hour\"], coords={\"hour\": peak_hours}\n", - ")\n", + "Previous versions used a shape-dependent heuristic that caused silent bugs (positional alignment on same-shape operands, non-associative addition, broken multiplication). The new behavior:\n", "\n", - "total_gen = gen.sum(\"tech\")\n", - "m3.add_constraints(total_gen.ge(peak_demand, join=\"inner\"), name=\"peak_demand\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "The demand constraint only applies during peak hours (8–20). Outside that range, no minimum generation is required." - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Summary\n", + "| Condition | Old | New |\n", + "|-----------|-----|-----|\n", + "| Same shape, different coords, `+`/`-` | Positional match (silent bug) | `ValueError` |\n", + "| Different shape, `+`/`-` | `\"outer\"` or `\"left\"` (implicit) | `ValueError` |\n", + "| Mismatched coords, `*`/`/` | Crash or garbage | Intersection (or error if empty) |\n", + "| Constraint with mismatched RHS | `\"override\"` or `\"left\"` | `ValueError` |\n", "\n", - "| ``join`` | Coordinates | Fill behavior |\n", - "|----------|------------|---------------|\n", - "| ``None`` (default) | Auto-detect (keeps superset) | Zeros for arithmetic, NaN for constraint RHS |\n", - "| ``\"inner\"`` | Intersection only | No fill needed |\n", - "| ``\"outer\"`` | Union | Fill with operation identity (0 for add, 0 for mul) |\n", - "| ``\"left\"`` | Left operand's | Fill right with identity |\n", - "| ``\"right\"`` | Right operand's | Fill left with identity |\n", - "| ``\"override\"`` | Left operand's (positional) | Positional alignment, ignore labels |\n", - "| ``\"exact\"`` | Must match exactly | Raises error if different |" + "To migrate: replace `x + y` with `x.add(y, join=\"outer\")` (or whichever join matches your intent)." ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -480,9 +490,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.11.11" } }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 5 } diff --git a/linopy/expressions.py b/linopy/expressions.py index e1fbe1a9..3a150d0d 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -48,7 +48,6 @@ LocIndexer, as_dataarray, assign_multiindex_safe, - check_common_keys_values, check_has_nulls, check_has_nulls_polars, fill_missing_coords, @@ -528,6 +527,7 @@ def _align_constant( other: DataArray, fill_value: float = 0, join: str | None = None, + default_join: str = "exact", ) -> tuple[DataArray, DataArray, bool]: """ Align a constant DataArray with self.const. @@ -539,7 +539,10 @@ def _align_constant( fill_value : float, default: 0 Fill value for missing coordinates. join : str, optional - Alignment method. If None, uses size-aware default behavior. + Alignment method. If None, uses default_join. + default_join : str, default: "exact" + Default join mode when join is None. Use "exact" for add/sub, + "inner" for mul/div. Returns ------- @@ -551,19 +554,32 @@ def _align_constant( Whether the expression's data needs reindexing. """ if join is None: - if other.sizes == self.const.sizes: - return self.const, other.assign_coords(coords=self.coords), False + join = default_join + + if join == "override": + return self.const, other.assign_coords(coords=self.coords), False + elif join == "left": return ( self.const, other.reindex_like(self.const, fill_value=fill_value), False, ) - elif join == "override": - return self.const, other.assign_coords(coords=self.coords), False else: - self_const, aligned = xr.align( - self.const, other, join=join, fill_value=fill_value - ) + try: + self_const, aligned = xr.align( + self.const, other, join=join, fill_value=fill_value + ) + except ValueError as e: + if "exact" in str(e): + raise ValueError( + f"{e}\n" + "Use .add()/.sub()/.mul()/.div() with an explicit join= parameter:\n" + ' .add(other, join="inner") # intersection of coordinates\n' + ' .add(other, join="outer") # union of coordinates (with fill)\n' + ' .add(other, join="left") # keep left operand\'s coordinates\n' + ' .add(other, join="override") # positional alignment' + ) from None + raise return self_const, aligned, True def _add_constant( @@ -572,8 +588,16 @@ def _add_constant( if np.isscalar(other) and join is None: return self.assign(const=self.const + other) da = as_dataarray(other, coords=self.coords, dims=self.coord_dims) + extra_dims = set(da.dims) - set(self.coord_dims) + if extra_dims: + raise ValueError( + f"Constant has dimensions {extra_dims} not present in the " + f"expression. Addition/subtraction cannot introduce new " + f"dimensions — use multiplication to expand, or select/reindex " + f"the constant to match the expression's dimensions." + ) self_const, da, needs_data_reindex = self._align_constant( - da, fill_value=0, join=join + da, fill_value=0, join=join, default_join="exact" ) if needs_data_reindex: return self.__class__( @@ -593,8 +617,13 @@ def _apply_constant_op( ) -> GenericExpression: factor = as_dataarray(other, coords=self.coords, dims=self.coord_dims) self_const, factor, needs_data_reindex = self._align_constant( - factor, fill_value=fill_value, join=join + factor, fill_value=fill_value, join=join, default_join="inner" ) + if self_const.size == 0 and self.const.size > 0: + raise ValueError( + "Multiplication/division resulted in an empty expression because " + "the operands have no overlapping coordinates (inner join)." + ) if needs_data_reindex: data = self.data.reindex_like(self_const, fill_value=self._fill_value) return self.__class__( @@ -1082,7 +1111,40 @@ def to_constraint( f"RHS DataArray has dimensions {extra_dims} not present " f"in the expression. Cannot create constraint." ) - rhs = rhs.reindex_like(self.const, fill_value=np.nan) + effective_join = join if join is not None else "exact" + if effective_join == "override": + aligned_rhs = rhs.assign_coords(coords=self.const.coords) + expr_const = self.const + expr_data = self.data + elif effective_join == "left": + aligned_rhs = rhs.reindex_like(self.const, fill_value=np.nan) + expr_const = self.const + expr_data = self.data + else: + try: + expr_const_aligned, aligned_rhs = xr.align( + self.const, rhs, join=effective_join, fill_value=np.nan + ) + except ValueError as e: + if "exact" in str(e): + raise ValueError( + f"{e}\n" + "Use .le()/.ge()/.eq() with an explicit join= parameter:\n" + ' .le(rhs, join="inner") # intersection of coordinates\n' + ' .le(rhs, join="left") # keep expression coordinates (NaN fill)\n' + ' .le(rhs, join="override") # positional alignment' + ) from None + raise + expr_const = expr_const_aligned.fillna(0) + expr_data = self.data.reindex_like( + expr_const_aligned, fill_value=self._fill_value + ) + aligned_rhs = aligned_rhs + constraint_rhs = aligned_rhs - expr_const + data = assign_multiindex_safe( + expr_data[["coeffs", "vars"]], sign=sign, rhs=constraint_rhs + ) + return constraints.Constraint(data, model=self.model) elif isinstance(rhs, np.ndarray | pd.Series | pd.DataFrame) and rhs.ndim > len( self.coord_dims ): @@ -2320,16 +2382,6 @@ def merge( model = exprs[0].model - if join is not None: - override = join == "override" - elif cls in linopy_types and dim in HELPER_DIMS: - coord_dims = [ - {k: v for k, v in e.sizes.items() if k not in HELPER_DIMS} for e in exprs - ] - override = check_common_keys_values(coord_dims) # type: ignore - else: - override = False - data = [e.data if isinstance(e, linopy_types) else e for e in exprs] data = [fill_missing_coords(ds, fill_helper_dims=True) for ds in data] @@ -2345,23 +2397,55 @@ def merge( if join is not None: kwargs["join"] = join - elif override: - kwargs["join"] = "override" + elif dim == TERM_DIM: + kwargs["join"] = "exact" + elif dim == FACTOR_DIM: + kwargs["join"] = "inner" else: - kwargs.setdefault("join", "outer") - - if dim == TERM_DIM: - ds = xr.concat([d[["coeffs", "vars"]] for d in data], dim, **kwargs) - subkwargs = {**kwargs, "fill_value": 0} - const = xr.concat([d["const"] for d in data], dim, **subkwargs).sum(TERM_DIM) - ds = assign_multiindex_safe(ds, const=const) - elif dim == FACTOR_DIM: - ds = xr.concat([d[["vars"]] for d in data], dim, **kwargs) - coeffs = xr.concat([d["coeffs"] for d in data], dim, **kwargs).prod(FACTOR_DIM) - const = xr.concat([d["const"] for d in data], dim, **kwargs).prod(FACTOR_DIM) - ds = assign_multiindex_safe(ds, coeffs=coeffs, const=const) - else: - ds = xr.concat(data, dim, **kwargs) + kwargs["join"] = "outer" + + try: + if dim == TERM_DIM: + ds = xr.concat([d[["coeffs", "vars"]] for d in data], dim, **kwargs) + subkwargs = {**kwargs, "fill_value": 0} + const = xr.concat([d["const"] for d in data], dim, **subkwargs).sum( + TERM_DIM + ) + ds = assign_multiindex_safe(ds, const=const) + elif dim == FACTOR_DIM: + ds = xr.concat([d[["vars"]] for d in data], dim, **kwargs) + coeffs = xr.concat([d["coeffs"] for d in data], dim, **kwargs).prod( + FACTOR_DIM + ) + const = xr.concat([d["const"] for d in data], dim, **kwargs).prod( + FACTOR_DIM + ) + ds = assign_multiindex_safe(ds, coeffs=coeffs, const=const) + else: + # Pre-pad helper dims to same size before concat + fill = kwargs.get("fill_value", FILL_VALUE) + for helper_dim in HELPER_DIMS: + sizes = [d.sizes.get(helper_dim, 0) for d in data] + max_size = max(sizes) if sizes else 0 + if max_size > 0 and min(sizes) < max_size: + data = [ + d.reindex({helper_dim: range(max_size)}, fill_value=fill) + if d.sizes.get(helper_dim, 0) < max_size + else d + for d in data + ] + ds = xr.concat(data, dim, **kwargs) + except ValueError as e: + if "exact" in str(e): + raise ValueError( + f"{e}\n" + "Use .add()/.sub()/.mul()/.div() with an explicit join= parameter:\n" + ' .add(other, join="inner") # intersection of coordinates\n' + ' .add(other, join="outer") # union of coordinates (with fill)\n' + ' .add(other, join="left") # keep left operand\'s coordinates\n' + ' .add(other, join="override") # positional alignment' + ) from None + raise for d in set(HELPER_DIMS) & set(ds.coords): ds = ds.reset_index(d, drop=True) diff --git a/linopy/variables.py b/linopy/variables.py index 0eea6634..274344a1 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -400,8 +400,9 @@ def __mul__(self, other: SideLike) -> ExpressionLike: try: if isinstance(other, Variable | ScalarVariable): return self.to_linexpr() * other - - return self.to_linexpr(other) + if isinstance(other, expressions.LinearExpression): + return self.to_linexpr() * other + return self.to_linexpr()._multiply_by_constant(other) except TypeError: return NotImplemented diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 2af1a8ea..2ced61a0 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -7,8 +7,6 @@ from __future__ import annotations -from typing import Any - import numpy as np import pandas as pd import polars as pl @@ -443,8 +441,12 @@ def test_linear_expression_sum( assert_linequal(expr.sum(["dim_0", TERM_DIM]), expr.sum("dim_0")) - # test special case otherride coords - expr = v.loc[:9] + v.loc[10:] + # disjoint coords now raise with exact default + with pytest.raises(ValueError, match="exact"): + v.loc[:9] + v.loc[10:] + + # positional alignment via assign_coords + expr = v.loc[:9] + v.loc[10:].assign_coords(dim_2=v.loc[:9].coords["dim_2"]) assert expr.nterm == 2 assert len(expr.coords["dim_2"]) == 10 @@ -467,8 +469,12 @@ def test_linear_expression_sum_with_const( assert_linequal(expr.sum(["dim_0", TERM_DIM]), expr.sum("dim_0")) - # test special case otherride coords - expr = v.loc[:9] + v.loc[10:] + # disjoint coords now raise with exact default + with pytest.raises(ValueError, match="exact"): + v.loc[:9] + v.loc[10:] + + # positional alignment via assign_coords + expr = v.loc[:9] + v.loc[10:].assign_coords(dim_2=v.loc[:9].coords["dim_2"]) assert expr.nterm == 2 assert len(expr.coords["dim_2"]) == 10 @@ -577,7 +583,18 @@ def test_linear_expression_multiplication_invalid( expr / x -class TestSubsetCoordinateAlignment: +class TestExactAlignmentDefault: + """ + Test the new alignment convention: exact for +/-, inner for *//. + + v has dim_2=[0..19] (20 entries). + subset has dim_2=[1, 3] (2 entries, subset of v's coords). + superset has dim_2=[0..24] (25 entries, superset of v's coords). + + Each test shows the operation, verifies the new behavior (raises or + intersection), then shows the explicit join= that recovers the old result. + """ + @pytest.fixture def subset(self) -> xr.DataArray: return xr.DataArray([10.0, 30.0], dims=["dim_2"], coords={"dim_2": [1, 3]}) @@ -588,298 +605,293 @@ def superset(self) -> xr.DataArray: np.arange(25, dtype=float), dims=["dim_2"], coords={"dim_2": range(25)} ) + @pytest.fixture + def matching(self) -> xr.DataArray: + return xr.DataArray( + np.arange(20, dtype=float), + dims=["dim_2"], + coords={"dim_2": range(20)}, + ) + @pytest.fixture def expected_fill(self) -> np.ndarray: + """Old expected result: 20-entry array with values at positions 1,3.""" arr = np.zeros(20) arr[1] = 10.0 arr[3] = 30.0 return arr - def test_var_mul_subset( - self, v: Variable, subset: xr.DataArray, expected_fill: np.ndarray - ) -> None: - result = v * subset - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.coeffs.values).any() - np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) - - def test_expr_mul_subset( - self, v: Variable, subset: xr.DataArray, expected_fill: np.ndarray - ) -> None: - expr = 1 * v - result = expr * subset - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.coeffs.values).any() - np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) - - @pytest.mark.parametrize( - "make_lhs,make_rhs", - [ - (lambda v, s: s * v, lambda v, s: v * s), - (lambda v, s: s * (1 * v), lambda v, s: (1 * v) * s), - (lambda v, s: s + v, lambda v, s: v + s), - (lambda v, s: s + (v + 5), lambda v, s: (v + 5) + s), - ], - ids=["subset*var", "subset*expr", "subset+var", "subset+expr"], - ) - def test_commutativity( - self, v: Variable, subset: xr.DataArray, make_lhs: Any, make_rhs: Any - ) -> None: - assert_linequal(make_lhs(v, subset), make_rhs(v, subset)) + # --- Addition / subtraction with subset constant --- def test_var_add_subset( self, v: Variable, subset: xr.DataArray, expected_fill: np.ndarray ) -> None: - result = v + subset - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.const.values).any() + # now raises + with pytest.raises(ValueError, match="exact"): + v + subset + + # explicit join="left" recovers old behavior: 20 entries, fill 0 + result = v.add(subset, join="left") + assert result.sizes["dim_2"] == 20 np.testing.assert_array_equal(result.const.values, expected_fill) def test_var_sub_subset( self, v: Variable, subset: xr.DataArray, expected_fill: np.ndarray ) -> None: - result = v - subset - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.const.values).any() - np.testing.assert_array_equal(result.const.values, -expected_fill) + with pytest.raises(ValueError, match="exact"): + v - subset - def test_subset_sub_var(self, v: Variable, subset: xr.DataArray) -> None: - assert_linequal(subset - v, -v + subset) + result = v.sub(subset, join="left") + assert result.sizes["dim_2"] == 20 + np.testing.assert_array_equal(result.const.values, -expected_fill) def test_expr_add_subset( self, v: Variable, subset: xr.DataArray, expected_fill: np.ndarray ) -> None: - expr = v + 5 - result = expr + subset - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.const.values).any() + with pytest.raises(ValueError, match="exact"): + (v + 5) + subset + + result = (v + 5).add(subset, join="left") + assert result.sizes["dim_2"] == 20 np.testing.assert_array_equal(result.const.values, expected_fill + 5) - def test_expr_sub_subset( + # --- Addition with superset constant --- + + def test_var_add_superset(self, v: Variable, superset: xr.DataArray) -> None: + with pytest.raises(ValueError, match="exact"): + v + superset + + result = v.add(superset, join="left") + assert result.sizes["dim_2"] == 20 + assert not np.isnan(result.const.values).any() + + # --- Addition / multiplication with disjoint coords --- + + def test_disjoint_add(self, v: Variable) -> None: + disjoint = xr.DataArray( + [100.0, 200.0], dims=["dim_2"], coords={"dim_2": [50, 60]} + ) + with pytest.raises(ValueError, match="exact"): + v + disjoint + + result = v.add(disjoint, join="outer") + assert result.sizes["dim_2"] == 22 # union of [0..19] and [50, 60] + + def test_disjoint_mul(self, v: Variable) -> None: + disjoint = xr.DataArray( + [10.0, 20.0], dims=["dim_2"], coords={"dim_2": [50, 60]} + ) + # inner join: no intersection → error + with pytest.raises(ValueError, match="no overlapping coordinates"): + v * disjoint + + # explicit join="left": 20 entries, all zeros + result = v.mul(disjoint, join="left") + assert result.sizes["dim_2"] == 20 + np.testing.assert_array_equal(result.coeffs.squeeze().values, np.zeros(20)) + + def test_disjoint_div(self, v: Variable) -> None: + disjoint = xr.DataArray( + [10.0, 20.0], dims=["dim_2"], coords={"dim_2": [50, 60]} + ) + with pytest.raises(ValueError, match="no overlapping coordinates"): + v / disjoint + + # --- Multiplication / division with subset constant --- + + def test_var_mul_subset( self, v: Variable, subset: xr.DataArray, expected_fill: np.ndarray ) -> None: - expr = v + 5 - result = expr - subset - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.const.values).any() - np.testing.assert_array_equal(result.const.values, 5 - expected_fill) + # inner join: 2 entries (intersection) + result = v * subset + assert result.sizes["dim_2"] == 2 + assert result.coeffs.squeeze().sel(dim_2=1).item() == pytest.approx(10.0) + assert result.coeffs.squeeze().sel(dim_2=3).item() == pytest.approx(30.0) - def test_subset_sub_expr(self, v: Variable, subset: xr.DataArray) -> None: - expr = v + 5 - assert_linequal(subset - expr, -(expr - subset)) + # explicit join="left" recovers old behavior: 20 entries, fill 0 + result = v.mul(subset, join="left") + assert result.sizes["dim_2"] == 20 + np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) + + def test_expr_mul_subset(self, v: Variable, subset: xr.DataArray) -> None: + result = (1 * v) * subset + assert result.sizes["dim_2"] == 2 + assert result.coeffs.squeeze().sel(dim_2=1).item() == pytest.approx(10.0) + + def test_var_mul_superset(self, v: Variable, superset: xr.DataArray) -> None: + # inner join: intersection = v's 20 coords + result = v * superset + assert result.sizes["dim_2"] == 20 + assert not np.isnan(result.coeffs.values).any() def test_var_div_subset(self, v: Variable, subset: xr.DataArray) -> None: + # inner join: 2 entries result = v / subset - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.coeffs.values).any() + assert result.sizes["dim_2"] == 2 + assert result.coeffs.squeeze().sel(dim_2=1).item() == pytest.approx(0.1) + assert result.coeffs.squeeze().sel(dim_2=3).item() == pytest.approx(1.0 / 30) + + # explicit join="left": 20 entries, fill 1 + result = v.div(subset, join="left") + assert result.sizes["dim_2"] == 20 assert result.coeffs.squeeze().sel(dim_2=1).item() == pytest.approx(0.1) assert result.coeffs.squeeze().sel(dim_2=0).item() == pytest.approx(1.0) + # --- Constraints with subset RHS --- + def test_var_le_subset(self, v: Variable, subset: xr.DataArray) -> None: - con = v <= subset - assert con.sizes["dim_2"] == v.sizes["dim_2"] - assert con.rhs.sel(dim_2=1).item() == 10.0 - assert con.rhs.sel(dim_2=3).item() == 30.0 - assert np.isnan(con.rhs.sel(dim_2=0).item()) + with pytest.raises(ValueError, match="exact"): + v <= subset - @pytest.mark.parametrize("sign", ["<=", ">=", "=="]) - def test_var_comparison_subset( - self, v: Variable, subset: xr.DataArray, sign: str - ) -> None: - if sign == "<=": - con = v <= subset - elif sign == ">=": - con = v >= subset - else: - con = v == subset - assert con.sizes["dim_2"] == v.sizes["dim_2"] + # explicit join="left": 20 entries, NaN where RHS missing + con = v.to_linexpr().le(subset, join="left") + assert con.sizes["dim_2"] == 20 assert con.rhs.sel(dim_2=1).item() == 10.0 + assert con.rhs.sel(dim_2=3).item() == 30.0 assert np.isnan(con.rhs.sel(dim_2=0).item()) def test_expr_le_subset(self, v: Variable, subset: xr.DataArray) -> None: expr = v + 5 - con = expr <= subset - assert con.sizes["dim_2"] == v.sizes["dim_2"] + with pytest.raises(ValueError, match="exact"): + expr <= subset + + con = expr.le(subset, join="left") + assert con.sizes["dim_2"] == 20 assert con.rhs.sel(dim_2=1).item() == pytest.approx(5.0) assert con.rhs.sel(dim_2=3).item() == pytest.approx(25.0) assert np.isnan(con.rhs.sel(dim_2=0).item()) - def test_add_commutativity_full_coords(self, v: Variable) -> None: - full = xr.DataArray( - np.arange(20, dtype=float), - dims=["dim_2"], - coords={"dim_2": range(20)}, - ) - assert_linequal(v + full, full + v) - - def test_superset_addition_pins_to_lhs( - self, v: Variable, superset: xr.DataArray + @pytest.mark.parametrize("sign", ["<=", ">=", "=="]) + def test_var_comparison_subset( + self, v: Variable, subset: xr.DataArray, sign: str ) -> None: - result = v + superset - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.const.values).any() - - def test_superset_add_var(self, v: Variable, superset: xr.DataArray) -> None: - assert_linequal(superset + v, v + superset) - - def test_superset_sub_var(self, v: Variable, superset: xr.DataArray) -> None: - assert_linequal(superset - v, -v + superset) + with pytest.raises(ValueError, match="exact"): + if sign == "<=": + v <= subset + elif sign == ">=": + v >= subset + else: + v == subset + + def test_constraint_le_join_inner(self, v: Variable, subset: xr.DataArray) -> None: + con = v.to_linexpr().le(subset, join="inner") + assert con.sizes["dim_2"] == 2 + assert con.rhs.sel(dim_2=1).item() == 10.0 + assert con.rhs.sel(dim_2=3).item() == 30.0 - def test_superset_mul_var(self, v: Variable, superset: xr.DataArray) -> None: - assert_linequal(superset * v, v * superset) + # --- Matching coordinates: unchanged behavior --- - @pytest.mark.parametrize("sign", ["<=", ">="]) - def test_superset_comparison_var( - self, v: Variable, superset: xr.DataArray, sign: str - ) -> None: - if sign == "<=": - con = superset <= v - else: - con = superset >= v - assert con.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(con.lhs.coeffs.values).any() - assert not np.isnan(con.rhs.values).any() - - def test_disjoint_addition_pins_to_lhs(self, v: Variable) -> None: - disjoint = xr.DataArray( - [100.0, 200.0], dims=["dim_2"], coords={"dim_2": [50, 60]} - ) - result = v + disjoint - assert result.sizes["dim_2"] == v.sizes["dim_2"] + def test_add_matching_unchanged(self, v: Variable, matching: xr.DataArray) -> None: + result = v + matching + assert result.sizes["dim_2"] == 20 assert not np.isnan(result.const.values).any() - np.testing.assert_array_equal(result.const.values, np.zeros(20)) - def test_expr_div_subset(self, v: Variable, subset: xr.DataArray) -> None: - expr = 1 * v - result = expr / subset - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.coeffs.values).any() - assert result.coeffs.squeeze().sel(dim_2=1).item() == pytest.approx(0.1) - assert result.coeffs.squeeze().sel(dim_2=0).item() == pytest.approx(1.0) + def test_mul_matching_unchanged(self, v: Variable, matching: xr.DataArray) -> None: + result = v * matching + assert result.sizes["dim_2"] == 20 - def test_subset_add_var_coefficients( - self, v: Variable, subset: xr.DataArray - ) -> None: - result = subset + v - np.testing.assert_array_equal(result.coeffs.squeeze().values, np.ones(20)) + def test_le_matching_unchanged(self, v: Variable, matching: xr.DataArray) -> None: + con = v <= matching + assert con.sizes["dim_2"] == 20 - def test_subset_sub_var_coefficients( - self, v: Variable, subset: xr.DataArray + def test_add_commutativity_matching( + self, v: Variable, matching: xr.DataArray ) -> None: - result = subset - v - np.testing.assert_array_equal(result.coeffs.squeeze().values, -np.ones(20)) + assert_linequal(v + matching, matching + v) - @pytest.mark.parametrize("sign", ["<=", ">=", "=="]) - def test_subset_comparison_var( - self, v: Variable, subset: xr.DataArray, sign: str - ) -> None: - if sign == "<=": - con = subset <= v - elif sign == ">=": - con = subset >= v - else: - con = subset == v - assert con.sizes["dim_2"] == v.sizes["dim_2"] - assert np.isnan(con.rhs.sel(dim_2=0).item()) - assert con.rhs.sel(dim_2=1).item() == pytest.approx(10.0) + def test_mul_commutativity(self, v: Variable, subset: xr.DataArray) -> None: + assert_linequal(v * subset, subset * v) - def test_superset_mul_pins_to_lhs( - self, v: Variable, superset: xr.DataArray - ) -> None: - result = v * superset - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.coeffs.values).any() + # --- Explicit join modes --- - def test_superset_div_pins_to_lhs(self, v: Variable) -> None: - superset_nonzero = xr.DataArray( - np.arange(1, 26, dtype=float), - dims=["dim_2"], - coords={"dim_2": range(25)}, + def test_add_join_inner(self, v: Variable, subset: xr.DataArray) -> None: + result = v.add(subset, join="inner") + assert result.sizes["dim_2"] == 2 + assert result.const.sel(dim_2=1).item() == 10.0 + assert result.const.sel(dim_2=3).item() == 30.0 + + def test_add_join_outer(self, v: Variable, subset: xr.DataArray) -> None: + result = v.add(subset, join="outer") + assert result.sizes["dim_2"] == 20 + assert result.const.sel(dim_2=1).item() == 10.0 + assert result.const.sel(dim_2=0).item() == 0.0 + + def test_add_positional_assign_coords(self, v: Variable) -> None: + disjoint = xr.DataArray( + np.ones(20), dims=["dim_2"], coords={"dim_2": range(50, 70)} ) - result = v / superset_nonzero - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.coeffs.values).any() + result = v + disjoint.assign_coords(dim_2=v.coords["dim_2"]) + assert result.sizes["dim_2"] == 20 + assert list(result.coords["dim_2"].values) == list(range(20)) + + # --- Quadratic expressions --- def test_quadexpr_add_subset( self, v: Variable, subset: xr.DataArray, expected_fill: np.ndarray ) -> None: qexpr = v * v - result = qexpr + subset - assert isinstance(result, QuadraticExpression) - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.const.values).any() - np.testing.assert_array_equal(result.const.values, expected_fill) + with pytest.raises(ValueError, match="exact"): + qexpr + subset - def test_quadexpr_sub_subset( - self, v: Variable, subset: xr.DataArray, expected_fill: np.ndarray - ) -> None: - qexpr = v * v - result = qexpr - subset + result = qexpr.add(subset, join="left") assert isinstance(result, QuadraticExpression) - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.const.values).any() - np.testing.assert_array_equal(result.const.values, -expected_fill) + assert result.sizes["dim_2"] == 20 + np.testing.assert_array_equal(result.const.values, expected_fill) def test_quadexpr_mul_subset( self, v: Variable, subset: xr.DataArray, expected_fill: np.ndarray ) -> None: qexpr = v * v + # inner join: 2 entries result = qexpr * subset assert isinstance(result, QuadraticExpression) - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.coeffs.values).any() - np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) + assert result.sizes["dim_2"] == 2 - def test_subset_mul_quadexpr( - self, v: Variable, subset: xr.DataArray, expected_fill: np.ndarray - ) -> None: - qexpr = v * v - result = subset * qexpr + # explicit join="left": 20 entries + result = qexpr.mul(subset, join="left") assert isinstance(result, QuadraticExpression) - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.coeffs.values).any() + assert result.sizes["dim_2"] == 20 np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) - def test_subset_add_quadexpr(self, v: Variable, subset: xr.DataArray) -> None: - qexpr = v * v - assert_quadequal(subset + qexpr, qexpr + subset) + # --- Multi-dimensional --- def test_multidim_subset_mul(self, m: Model) -> None: coords_a = pd.RangeIndex(4, name="a") coords_b = pd.RangeIndex(5, name="b") w = m.add_variables(coords=[coords_a, coords_b], name="w") - subset_2d = xr.DataArray( [[2.0, 3.0], [4.0, 5.0]], dims=["a", "b"], coords={"a": [1, 3], "b": [0, 4]}, ) + + # inner join: 2x2 result = w * subset_2d + assert result.sizes["a"] == 2 + assert result.sizes["b"] == 2 + + # explicit join="left": 4x5, zeros at non-subset positions + result = w.mul(subset_2d, join="left") assert result.sizes["a"] == 4 assert result.sizes["b"] == 5 - assert not np.isnan(result.coeffs.values).any() assert result.coeffs.squeeze().sel(a=1, b=0).item() == pytest.approx(2.0) assert result.coeffs.squeeze().sel(a=3, b=4).item() == pytest.approx(5.0) assert result.coeffs.squeeze().sel(a=0, b=0).item() == pytest.approx(0.0) - assert result.coeffs.squeeze().sel(a=1, b=2).item() == pytest.approx(0.0) def test_multidim_subset_add(self, m: Model) -> None: coords_a = pd.RangeIndex(4, name="a") coords_b = pd.RangeIndex(5, name="b") w = m.add_variables(coords=[coords_a, coords_b], name="w") - subset_2d = xr.DataArray( [[2.0, 3.0], [4.0, 5.0]], dims=["a", "b"], coords={"a": [1, 3], "b": [0, 4]}, ) - result = w + subset_2d - assert result.sizes["a"] == 4 - assert result.sizes["b"] == 5 - assert not np.isnan(result.const.values).any() - assert result.const.sel(a=1, b=0).item() == pytest.approx(2.0) - assert result.const.sel(a=3, b=4).item() == pytest.approx(5.0) - assert result.const.sel(a=0, b=0).item() == pytest.approx(0.0) + + with pytest.raises(ValueError, match="exact"): + w + subset_2d + + # --- Edge cases --- def test_constraint_rhs_extra_dims_raises(self, v: Variable) -> None: rhs = xr.DataArray( @@ -888,29 +900,23 @@ def test_constraint_rhs_extra_dims_raises(self, v: Variable) -> None: with pytest.raises(ValueError, match="not present in the expression"): v <= rhs + def test_add_constant_extra_dims_raises(self, v: Variable) -> None: + da = xr.DataArray( + [[1.0, 2.0]], dims=["extra", "dim_2"], coords={"dim_2": [0, 1]} + ) + with pytest.raises(ValueError, match="not present in the expression"): + v + da + with pytest.raises(ValueError, match="not present in the expression"): + v - da + # multiplication still allows extra dims (broadcasts) + result = v * da + assert "extra" in result.dims + def test_da_truediv_var_raises(self, v: Variable) -> None: da = xr.DataArray(np.ones(20), dims=["dim_2"], coords={"dim_2": range(20)}) with pytest.raises(TypeError): da / v # type: ignore[operator] - def test_disjoint_mul_produces_zeros(self, v: Variable) -> None: - disjoint = xr.DataArray( - [10.0, 20.0], dims=["dim_2"], coords={"dim_2": [50, 60]} - ) - result = v * disjoint - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.coeffs.values).any() - np.testing.assert_array_equal(result.coeffs.squeeze().values, np.zeros(20)) - - def test_disjoint_div_preserves_coeffs(self, v: Variable) -> None: - disjoint = xr.DataArray( - [10.0, 20.0], dims=["dim_2"], coords={"dim_2": [50, 60]} - ) - result = v / disjoint - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.coeffs.values).any() - np.testing.assert_array_equal(result.coeffs.squeeze().values, np.ones(20)) - def test_da_eq_da_still_works(self) -> None: da1 = xr.DataArray([1, 2, 3]) da2 = xr.DataArray([1, 2, 3]) @@ -931,7 +937,8 @@ def test_subset_constraint_solve_integration(self) -> None: coords = pd.RangeIndex(5, name="i") x = m.add_variables(lower=0, upper=100, coords=[coords], name="x") subset_ub = xr.DataArray([10.0, 20.0], dims=["i"], coords={"i": [1, 3]}) - m.add_constraints(x <= subset_ub, name="subset_ub") + # exact default raises — use explicit join="left" (NaN = no constraint) + m.add_constraints(x.to_linexpr().le(subset_ub, join="left"), name="subset_ub") m.add_objective(x.sum(), sense="max") m.solve(solver_name=available_solvers[0]) sol = m.solution["x"] @@ -1789,10 +1796,12 @@ def b(self, m2: Model) -> Variable: def c(self, m2: Model) -> Variable: return m2.variables["c"] - def test_add_join_none_preserves_default(self, a: Variable, b: Variable) -> None: - result_default = a.to_linexpr() + b.to_linexpr() - result_none = a.to_linexpr().add(b.to_linexpr(), join=None) - assert_linequal(result_default, result_none) + def test_add_join_none_raises_on_mismatch(self, a: Variable, b: Variable) -> None: + # a has i=[0,1,2], b has i=[1,2,3] — exact default raises + with pytest.raises(ValueError, match="exact"): + a.to_linexpr() + b.to_linexpr() + with pytest.raises(ValueError, match="exact"): + a.to_linexpr().add(b.to_linexpr(), join=None) def test_add_expr_join_inner(self, a: Variable, b: Variable) -> None: result = a.to_linexpr().add(b.to_linexpr(), join="inner") @@ -1820,10 +1829,10 @@ def test_add_constant_join_outer(self, a: Variable) -> None: result = a.to_linexpr().add(const, join="outer") assert list(result.data.indexes["i"]) == [0, 1, 2, 3] - def test_add_constant_join_override(self, a: Variable, c: Variable) -> None: + def test_add_constant_positional(self, a: Variable) -> None: expr = a.to_linexpr() - const = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [0, 1, 2]}) - result = expr.add(const, join="override") + const = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [5, 6, 7]}) + result = expr + const.assign_coords(i=expr.coords["i"]) assert list(result.data.indexes["i"]) == [0, 1, 2] assert (result.const.values == const.values).all() @@ -1880,8 +1889,8 @@ def test_merge_join_parameter(self, a: Variable, b: Variable) -> None: result: LinearExpression = merge([a.to_linexpr(), b.to_linexpr()], join="inner") assert list(result.data.indexes["i"]) == [1, 2] - def test_same_shape_add_join_override(self, a: Variable, c: Variable) -> None: - result = a.to_linexpr().add(c.to_linexpr(), join="override") + def test_same_shape_add_assign_coords(self, a: Variable, c: Variable) -> None: + result = a.to_linexpr() + c.to_linexpr().assign_coords(i=a.coords["i"]) assert list(result.data.indexes["i"]) == [0, 1, 2] def test_add_expr_outer_const_values(self, a: Variable, b: Variable) -> None: @@ -1919,17 +1928,17 @@ def test_add_constant_inner_fill_values(self, a: Variable) -> None: assert list(result.coords["i"].values) == [1] assert result.const.sel(i=1).item() == 15 - def test_add_constant_override_positional(self, a: Variable) -> None: + def test_add_constant_positional_different_coords(self, a: Variable) -> None: expr = 1 * a + 5 other = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [5, 6, 7]}) - result = expr.add(other, join="override") + result = expr + other.assign_coords(i=expr.coords["i"]) assert list(result.coords["i"].values) == [0, 1, 2] np.testing.assert_array_equal(result.const.values, [15, 25, 35]) - def test_sub_constant_override(self, a: Variable) -> None: + def test_sub_constant_positional(self, a: Variable) -> None: expr = 1 * a + 5 other = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [5, 6, 7]}) - result = expr.sub(other, join="override") + result = expr - other.assign_coords(i=expr.coords["i"]) assert list(result.coords["i"].values) == [0, 1, 2] np.testing.assert_array_equal(result.const.values, [-5, -15, -25]) @@ -1943,10 +1952,10 @@ def test_sub_expr_outer_const_values(self, a: Variable, b: Variable) -> None: assert result.const.sel(i=2).item() == -5 assert result.const.sel(i=3).item() == -10 - def test_mul_constant_override_positional(self, a: Variable) -> None: + def test_mul_constant_positional(self, a: Variable) -> None: expr = 1 * a + 5 other = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [5, 6, 7]}) - result = expr.mul(other, join="override") + result = expr * other.assign_coords(i=expr.coords["i"]) assert list(result.coords["i"].values) == [0, 1, 2] np.testing.assert_array_equal(result.const.values, [10, 15, 20]) np.testing.assert_array_equal(result.coeffs.squeeze().values, [2, 3, 4]) @@ -1963,10 +1972,10 @@ def test_mul_constant_outer_fill_values(self, a: Variable) -> None: assert result.coeffs.squeeze().sel(i=1).item() == 2 assert result.coeffs.squeeze().sel(i=0).item() == 0 - def test_div_constant_override_positional(self, a: Variable) -> None: + def test_div_constant_positional(self, a: Variable) -> None: expr = 1 * a + 10 other = xr.DataArray([2.0, 5.0, 10.0], dims=["i"], coords={"i": [5, 6, 7]}) - result = expr.div(other, join="override") + result = expr / other.assign_coords(i=expr.coords["i"]) assert list(result.coords["i"].values) == [0, 1, 2] np.testing.assert_array_equal(result.const.values, [5.0, 2.0, 1.0]) @@ -1990,16 +1999,16 @@ def test_variable_add_outer_values(self, a: Variable, b: Variable) -> None: assert set(result.coords["i"].values) == {0, 1, 2, 3} assert result.nterm == 2 - def test_variable_mul_override(self, a: Variable) -> None: + def test_variable_mul_positional(self, a: Variable) -> None: other = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [5, 6, 7]}) - result = a.mul(other, join="override") + result = a * other.assign_coords(i=a.coords["i"]) assert isinstance(result, LinearExpression) assert list(result.coords["i"].values) == [0, 1, 2] np.testing.assert_array_equal(result.coeffs.squeeze().values, [2, 3, 4]) - def test_variable_div_override(self, a: Variable) -> None: + def test_variable_div_positional(self, a: Variable) -> None: other = xr.DataArray([2.0, 5.0, 10.0], dims=["i"], coords={"i": [5, 6, 7]}) - result = a.div(other, join="override") + result = a / other.assign_coords(i=a.coords["i"]) assert isinstance(result, LinearExpression) assert list(result.coords["i"].values) == [0, 1, 2] np.testing.assert_array_almost_equal( @@ -2013,14 +2022,18 @@ def test_merge_outer_join(self, a: Variable, b: Variable) -> None: def test_add_same_coords_all_joins(self, a: Variable, c: Variable) -> None: expr_a = 1 * a + 5 const = xr.DataArray([1, 2, 3], dims=["i"], coords={"i": [0, 1, 2]}) - for join in ["override", "outer", "inner"]: + for join in ["outer", "inner"]: result = expr_a.add(const, join=join) assert list(result.coords["i"].values) == [0, 1, 2] np.testing.assert_array_equal(result.const.values, [6, 7, 8]) + # assign_coords also works when coords already match + result = expr_a + const.assign_coords(i=expr_a.coords["i"]) + assert list(result.coords["i"].values) == [0, 1, 2] + np.testing.assert_array_equal(result.const.values, [6, 7, 8]) - def test_add_scalar_with_explicit_join(self, a: Variable) -> None: + def test_add_scalar(self, a: Variable) -> None: expr = 1 * a + 5 - result = expr.add(10, join="override") + result = expr + 10 np.testing.assert_array_equal(result.const.values, [15, 15, 15]) assert list(result.coords["i"].values) == [0, 1, 2] @@ -2028,10 +2041,10 @@ def test_quadratic_add_constant_join_inner(self, a: Variable, b: Variable) -> No quad = a.to_linexpr() * b.to_linexpr() const = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [1, 2, 3]}) result = quad.add(const, join="inner") - assert list(result.data.indexes["i"]) == [1, 2, 3] + assert list(result.data.indexes["i"]) == [1, 2] - def test_quadratic_add_expr_join_inner(self, a: Variable) -> None: - quad = a.to_linexpr() * a.to_linexpr() + def test_quadratic_add_expr_join_inner(self, a: Variable, b: Variable) -> None: + quad = a.to_linexpr() * b.to_linexpr() const = xr.DataArray([10, 20], dims=["i"], coords={"i": [0, 1]}) result = quad.add(const, join="inner") assert list(result.data.indexes["i"]) == [0, 1] @@ -2040,7 +2053,7 @@ def test_quadratic_mul_constant_join_inner(self, a: Variable, b: Variable) -> No quad = a.to_linexpr() * b.to_linexpr() const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) result = quad.mul(const, join="inner") - assert list(result.data.indexes["i"]) == [1, 2, 3] + assert list(result.data.indexes["i"]) == [1, 2] def test_merge_join_left(self, a: Variable, b: Variable) -> None: result: LinearExpression = merge([a.to_linexpr(), b.to_linexpr()], join="left") diff --git a/test/test_optimization.py b/test/test_optimization.py index 492d703a..6bcb1627 100644 --- a/test/test_optimization.py +++ b/test/test_optimization.py @@ -186,8 +186,8 @@ def model_with_non_aligned_variables() -> Model: lower = pd.Series(0, range(8)) y = m.add_variables(lower=lower, coords=[lower.index], name="y") - m.add_constraints(x + y, GREATER_EQUAL, 10.5) - m.objective = 1 * x + 0.5 * y + m.add_constraints(x.add(y, join="outer"), GREATER_EQUAL, 10.5) + m.objective = x.add(0.5 * y, join="outer") return m diff --git a/test/test_typing.py b/test/test_typing.py index 99a27033..312f76c9 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -7,6 +7,7 @@ def test_operations_with_data_arrays_are_typed_correctly() -> None: m = linopy.Model() a: xr.DataArray = xr.DataArray([1, 2, 3]) + s: xr.DataArray = xr.DataArray(5.0) v: linopy.Variable = m.add_variables(lower=0.0, name="v") e: linopy.LinearExpression = v * 1.0 @@ -14,12 +15,12 @@ def test_operations_with_data_arrays_are_typed_correctly() -> None: _ = a * v _ = v * a - _ = v + a + _ = v + s _ = a * e _ = e * a - _ = e + a + _ = e + s _ = a * q _ = q * a - _ = q + a + _ = q + s