Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions linopy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,7 @@ def solve(
remote: RemoteHandler | OetcHandler = None, # type: ignore
progress: bool | None = None,
mock_solve: bool = False,
reformulate_sos: bool = False,
reformulate_sos: bool | Literal["auto"] = False,
**solver_options: Any,
) -> tuple[str, str]:
"""
Expand Down Expand Up @@ -1297,9 +1297,12 @@ def solve(
than 10000 variables and constraints.
mock_solve : bool, optional
Whether to run a mock solve. This will skip the actual solving. Variables will be set to have dummy values
reformulate_sos : bool, optional
reformulate_sos : bool | Literal["auto"], optional
Whether to automatically reformulate SOS constraints as binary + linear
constraints for solvers that don't support them natively.
If True, always reformulates (warns if solver supports SOS natively).
If "auto", silently reformulates only when the solver lacks SOS support.
If False, raises if solver doesn't support SOS.
This uses the Big-M method and requires all SOS variables to have finite bounds.
Default is False.
**solver_options : kwargs
Expand Down Expand Up @@ -1399,24 +1402,27 @@ def solve(
f"Solver {solver_name} does not support quadratic problems."
)

if reformulate_sos not in (True, False, "auto"):
raise ValueError(
f"Invalid value for reformulate_sos: {reformulate_sos!r}. "
"Must be True, False, or 'auto'."
)

sos_reform_result = None
if self.variables.sos:
if reformulate_sos and not solver_supports(
solver_name, SolverFeature.SOS_CONSTRAINTS
):
supports_sos = solver_supports(solver_name, SolverFeature.SOS_CONSTRAINTS)
if reformulate_sos in (True, "auto") and not supports_sos:
logger.info(f"Reformulating SOS constraints for solver {solver_name}")
sos_reform_result = reformulate_sos_constraints(self)
elif reformulate_sos and solver_supports(
solver_name, SolverFeature.SOS_CONSTRAINTS
):
elif reformulate_sos is True and supports_sos:
logger.warning(
f"Solver {solver_name} supports SOS natively; "
"reformulate_sos=True is ignored."
)
elif not solver_supports(solver_name, SolverFeature.SOS_CONSTRAINTS):
elif reformulate_sos is False and not supports_sos:
raise ValueError(
f"Solver {solver_name} does not support SOS constraints. "
"Use reformulate_sos=True or a solver that supports SOS (gurobi, cplex)."
"Use reformulate_sos=True or 'auto', or a solver that supports SOS (gurobi, cplex)."
)

try:
Expand Down
14 changes: 8 additions & 6 deletions linopy/piecewise.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from __future__ import annotations

from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Literal

import numpy as np
Expand Down Expand Up @@ -58,7 +58,7 @@ def _dict_to_array(d: dict[str, list[float]], dim: str, bp_dim: str) -> DataArra


def _segments_list_to_array(
values: list[list[float]], bp_dim: str, seg_dim: str
values: list[Sequence[float]], bp_dim: str, seg_dim: str
) -> DataArray:
max_len = max(len(seg) for seg in values)
data = np.full((len(values), max_len), np.nan)
Expand All @@ -72,7 +72,7 @@ def _segments_list_to_array(


def _dict_segments_to_array(
d: dict[str, list[list[float]]], dim: str, bp_dim: str, seg_dim: str
d: dict[str, list[Sequence[float]]], dim: str, bp_dim: str, seg_dim: str
) -> DataArray:
parts = []
for key, seg_list in d.items():
Expand Down Expand Up @@ -138,7 +138,9 @@ def _resolve_kwargs(


def _resolve_segment_kwargs(
kwargs: dict[str, list[list[float]] | dict[str, list[list[float]]] | DataArray],
kwargs: dict[
str, list[Sequence[float]] | dict[str, list[Sequence[float]]] | DataArray
],
dim: str | None,
bp_dim: str,
seg_dim: str,
Expand Down Expand Up @@ -235,13 +237,13 @@ def __call__(

def segments(
self,
values: list[list[float]] | dict[str, list[list[float]]] | None = None,
values: list[Sequence[float]] | dict[str, list[Sequence[float]]] | None = None,
*,
dim: str | None = None,
bp_dim: str = DEFAULT_BREAKPOINT_DIM,
seg_dim: str = DEFAULT_SEGMENT_DIM,
link_dim: str = DEFAULT_LINK_DIM,
**kwargs: list[list[float]] | dict[str, list[list[float]]] | DataArray,
**kwargs: list[Sequence[float]] | dict[str, list[Sequence[float]]] | DataArray,
) -> DataArray:
"""
Create a segmented breakpoint DataArray for disjunctive piecewise constraints.
Expand Down
117 changes: 116 additions & 1 deletion test/test_sos_reformulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from __future__ import annotations

import logging

import numpy as np
import pandas as pd
import pytest

from linopy import Model, available_solvers
from linopy import Model, Variable, available_solvers
from linopy.constants import SOS_TYPE_ATTR
from linopy.sos_reformulation import (
compute_big_m_values,
Expand Down Expand Up @@ -816,3 +818,116 @@ def test_sos1_unsorted_coords(self) -> None:

assert m.objective.value is not None
assert np.isclose(m.objective.value, 3, atol=1e-5)


@pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed")
class TestAutoReformulation:
"""Tests for reformulate_sos='auto' functionality."""

@pytest.fixture()
def sos1_model(self) -> tuple[Model, Variable]:
m = Model()
idx = pd.Index([0, 1, 2], name="i")
x = m.add_variables(lower=0, upper=1, coords=[idx], name="x")
m.add_sos_constraints(x, sos_type=1, sos_dim="i")
m.add_objective(x * np.array([1, 2, 3]), sense="max")
return m, x

def test_auto_reformulates_when_solver_lacks_sos(
self, sos1_model: tuple[Model, Variable]
) -> None:
m, x = sos1_model
m.solve(solver_name="highs", reformulate_sos="auto")

assert np.isclose(x.solution.values[2], 1, atol=1e-5)
assert np.isclose(x.solution.values[0], 0, atol=1e-5)
assert np.isclose(x.solution.values[1], 0, atol=1e-5)
assert m.objective.value is not None
assert np.isclose(m.objective.value, 3, atol=1e-5)

def test_auto_with_sos2(self) -> None:
m = Model()
idx = pd.Index([0, 1, 2, 3], name="i")
x = m.add_variables(lower=0, upper=1, coords=[idx], name="x")
m.add_sos_constraints(x, sos_type=2, sos_dim="i")
m.add_objective(x * np.array([10, 1, 1, 10]), sense="max")

m.solve(solver_name="highs", reformulate_sos="auto")

assert m.objective.value is not None
nonzero_indices = np.where(np.abs(x.solution.values) > 1e-5)[0]
assert len(nonzero_indices) <= 2
if len(nonzero_indices) == 2:
assert abs(nonzero_indices[1] - nonzero_indices[0]) == 1
assert not np.isclose(m.objective.value, 20, atol=1e-5)

def test_auto_emits_info_no_warning(
self, sos1_model: tuple[Model, Variable], caplog: pytest.LogCaptureFixture
) -> None:
m, _ = sos1_model

with caplog.at_level(logging.INFO):
m.solve(solver_name="highs", reformulate_sos="auto")

assert any("Reformulating SOS" in msg for msg in caplog.messages)
assert not any("supports SOS natively" in msg for msg in caplog.messages)

@pytest.mark.skipif(
"gurobi" not in available_solvers, reason="Gurobi not installed"
)
def test_auto_passes_through_native_sos_without_reformulation(self) -> None:
import gurobipy

m = Model()
idx = pd.Index([0, 1, 2], name="i")
x = m.add_variables(lower=0, upper=1, coords=[idx], name="x")
m.add_sos_constraints(x, sos_type=1, sos_dim="i")
m.add_objective(x * np.array([1, 2, 3]), sense="max")

try:
m.solve(solver_name="gurobi", reformulate_sos="auto")
except gurobipy.GurobiError as exc:
pytest.skip(f"Gurobi environment unavailable: {exc}")

assert m.objective.value is not None
assert np.isclose(m.objective.value, 3, atol=1e-5)
assert np.isclose(x.solution.values[2], 1, atol=1e-5)
assert np.isclose(x.solution.values[0], 0, atol=1e-5)
assert np.isclose(x.solution.values[1], 0, atol=1e-5)

def test_auto_multidimensional_sos1(self) -> None:
m = Model()
idx_i = pd.Index([0, 1, 2], name="i")
idx_j = pd.Index([0, 1], name="j")
x = m.add_variables(lower=0, upper=1, coords=[idx_i, idx_j], name="x")
m.add_sos_constraints(x, sos_type=1, sos_dim="i")
m.add_objective(x.sum(), sense="max")

m.solve(solver_name="highs", reformulate_sos="auto")

assert m.objective.value is not None
assert np.isclose(m.objective.value, 2, atol=1e-5)
for j in idx_j:
nonzero_count = (np.abs(x.solution.sel(j=j).values) > 1e-5).sum()
assert nonzero_count <= 1

def test_auto_noop_without_sos(self) -> None:
m = Model()
idx = pd.Index([0, 1, 2], name="i")
x = m.add_variables(lower=0, upper=1, coords=[idx], name="x")
m.add_objective(x.sum(), sense="max")

m.solve(solver_name="highs", reformulate_sos="auto")

assert m.objective.value is not None
assert np.isclose(m.objective.value, 3, atol=1e-5)

def test_invalid_reformulate_sos_value(self) -> None:
m = Model()
idx = pd.Index([0, 1, 2], name="i")
x = m.add_variables(lower=0, upper=1, coords=[idx], name="x")
m.add_sos_constraints(x, sos_type=1, sos_dim="i")
m.add_objective(x.sum(), sense="max")

with pytest.raises(ValueError, match="Invalid value for reformulate_sos"):
m.solve(solver_name="highs", reformulate_sos="invalid") # type: ignore[arg-type]