diff --git a/linopy/model.py b/linopy/model.py index 1901a4b9..7b8396f4 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -1227,7 +1227,7 @@ def solve( remote: RemoteHandler | OetcHandler = None, # type: ignore progress: bool | None = None, mock_solve: bool = False, - reformulate_sos: bool = False, + reformulate_sos: bool | Literal["auto"] = False, **solver_options: Any, ) -> tuple[str, str]: """ @@ -1297,9 +1297,12 @@ def solve( than 10000 variables and constraints. mock_solve : bool, optional Whether to run a mock solve. This will skip the actual solving. Variables will be set to have dummy values - reformulate_sos : bool, optional + reformulate_sos : bool | Literal["auto"], optional Whether to automatically reformulate SOS constraints as binary + linear constraints for solvers that don't support them natively. + If True, always reformulates (warns if solver supports SOS natively). + If "auto", silently reformulates only when the solver lacks SOS support. + If False, raises if solver doesn't support SOS. This uses the Big-M method and requires all SOS variables to have finite bounds. Default is False. **solver_options : kwargs @@ -1399,24 +1402,27 @@ def solve( f"Solver {solver_name} does not support quadratic problems." ) + if reformulate_sos not in (True, False, "auto"): + raise ValueError( + f"Invalid value for reformulate_sos: {reformulate_sos!r}. " + "Must be True, False, or 'auto'." + ) + sos_reform_result = None if self.variables.sos: - if reformulate_sos and not solver_supports( - solver_name, SolverFeature.SOS_CONSTRAINTS - ): + supports_sos = solver_supports(solver_name, SolverFeature.SOS_CONSTRAINTS) + if reformulate_sos in (True, "auto") and not supports_sos: logger.info(f"Reformulating SOS constraints for solver {solver_name}") sos_reform_result = reformulate_sos_constraints(self) - elif reformulate_sos and solver_supports( - solver_name, SolverFeature.SOS_CONSTRAINTS - ): + elif reformulate_sos is True and supports_sos: logger.warning( f"Solver {solver_name} supports SOS natively; " "reformulate_sos=True is ignored." ) - elif not solver_supports(solver_name, SolverFeature.SOS_CONSTRAINTS): + elif reformulate_sos is False and not supports_sos: raise ValueError( f"Solver {solver_name} does not support SOS constraints. " - "Use reformulate_sos=True or a solver that supports SOS (gurobi, cplex)." + "Use reformulate_sos=True or 'auto', or a solver that supports SOS (gurobi, cplex)." ) try: diff --git a/linopy/piecewise.py b/linopy/piecewise.py index fd42bcc0..5128d1e5 100644 --- a/linopy/piecewise.py +++ b/linopy/piecewise.py @@ -7,7 +7,7 @@ from __future__ import annotations -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Literal import numpy as np @@ -58,7 +58,7 @@ def _dict_to_array(d: dict[str, list[float]], dim: str, bp_dim: str) -> DataArra def _segments_list_to_array( - values: list[list[float]], bp_dim: str, seg_dim: str + values: list[Sequence[float]], bp_dim: str, seg_dim: str ) -> DataArray: max_len = max(len(seg) for seg in values) data = np.full((len(values), max_len), np.nan) @@ -72,7 +72,7 @@ def _segments_list_to_array( def _dict_segments_to_array( - d: dict[str, list[list[float]]], dim: str, bp_dim: str, seg_dim: str + d: dict[str, list[Sequence[float]]], dim: str, bp_dim: str, seg_dim: str ) -> DataArray: parts = [] for key, seg_list in d.items(): @@ -138,7 +138,9 @@ def _resolve_kwargs( def _resolve_segment_kwargs( - kwargs: dict[str, list[list[float]] | dict[str, list[list[float]]] | DataArray], + kwargs: dict[ + str, list[Sequence[float]] | dict[str, list[Sequence[float]]] | DataArray + ], dim: str | None, bp_dim: str, seg_dim: str, @@ -235,13 +237,13 @@ def __call__( def segments( self, - values: list[list[float]] | dict[str, list[list[float]]] | None = None, + values: list[Sequence[float]] | dict[str, list[Sequence[float]]] | None = None, *, dim: str | None = None, bp_dim: str = DEFAULT_BREAKPOINT_DIM, seg_dim: str = DEFAULT_SEGMENT_DIM, link_dim: str = DEFAULT_LINK_DIM, - **kwargs: list[list[float]] | dict[str, list[list[float]]] | DataArray, + **kwargs: list[Sequence[float]] | dict[str, list[Sequence[float]]] | DataArray, ) -> DataArray: """ Create a segmented breakpoint DataArray for disjunctive piecewise constraints. diff --git a/test/test_sos_reformulation.py b/test/test_sos_reformulation.py index f45ea706..24ba62b3 100644 --- a/test/test_sos_reformulation.py +++ b/test/test_sos_reformulation.py @@ -2,11 +2,13 @@ from __future__ import annotations +import logging + import numpy as np import pandas as pd import pytest -from linopy import Model, available_solvers +from linopy import Model, Variable, available_solvers from linopy.constants import SOS_TYPE_ATTR from linopy.sos_reformulation import ( compute_big_m_values, @@ -816,3 +818,116 @@ def test_sos1_unsorted_coords(self) -> None: assert m.objective.value is not None assert np.isclose(m.objective.value, 3, atol=1e-5) + + +@pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed") +class TestAutoReformulation: + """Tests for reformulate_sos='auto' functionality.""" + + @pytest.fixture() + def sos1_model(self) -> tuple[Model, Variable]: + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + m.add_objective(x * np.array([1, 2, 3]), sense="max") + return m, x + + def test_auto_reformulates_when_solver_lacks_sos( + self, sos1_model: tuple[Model, Variable] + ) -> None: + m, x = sos1_model + m.solve(solver_name="highs", reformulate_sos="auto") + + assert np.isclose(x.solution.values[2], 1, atol=1e-5) + assert np.isclose(x.solution.values[0], 0, atol=1e-5) + assert np.isclose(x.solution.values[1], 0, atol=1e-5) + assert m.objective.value is not None + assert np.isclose(m.objective.value, 3, atol=1e-5) + + def test_auto_with_sos2(self) -> None: + m = Model() + idx = pd.Index([0, 1, 2, 3], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=2, sos_dim="i") + m.add_objective(x * np.array([10, 1, 1, 10]), sense="max") + + m.solve(solver_name="highs", reformulate_sos="auto") + + assert m.objective.value is not None + nonzero_indices = np.where(np.abs(x.solution.values) > 1e-5)[0] + assert len(nonzero_indices) <= 2 + if len(nonzero_indices) == 2: + assert abs(nonzero_indices[1] - nonzero_indices[0]) == 1 + assert not np.isclose(m.objective.value, 20, atol=1e-5) + + def test_auto_emits_info_no_warning( + self, sos1_model: tuple[Model, Variable], caplog: pytest.LogCaptureFixture + ) -> None: + m, _ = sos1_model + + with caplog.at_level(logging.INFO): + m.solve(solver_name="highs", reformulate_sos="auto") + + assert any("Reformulating SOS" in msg for msg in caplog.messages) + assert not any("supports SOS natively" in msg for msg in caplog.messages) + + @pytest.mark.skipif( + "gurobi" not in available_solvers, reason="Gurobi not installed" + ) + def test_auto_passes_through_native_sos_without_reformulation(self) -> None: + import gurobipy + + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + m.add_objective(x * np.array([1, 2, 3]), sense="max") + + try: + m.solve(solver_name="gurobi", reformulate_sos="auto") + except gurobipy.GurobiError as exc: + pytest.skip(f"Gurobi environment unavailable: {exc}") + + assert m.objective.value is not None + assert np.isclose(m.objective.value, 3, atol=1e-5) + assert np.isclose(x.solution.values[2], 1, atol=1e-5) + assert np.isclose(x.solution.values[0], 0, atol=1e-5) + assert np.isclose(x.solution.values[1], 0, atol=1e-5) + + def test_auto_multidimensional_sos1(self) -> None: + m = Model() + idx_i = pd.Index([0, 1, 2], name="i") + idx_j = pd.Index([0, 1], name="j") + x = m.add_variables(lower=0, upper=1, coords=[idx_i, idx_j], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + m.add_objective(x.sum(), sense="max") + + m.solve(solver_name="highs", reformulate_sos="auto") + + assert m.objective.value is not None + assert np.isclose(m.objective.value, 2, atol=1e-5) + for j in idx_j: + nonzero_count = (np.abs(x.solution.sel(j=j).values) > 1e-5).sum() + assert nonzero_count <= 1 + + def test_auto_noop_without_sos(self) -> None: + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_objective(x.sum(), sense="max") + + m.solve(solver_name="highs", reformulate_sos="auto") + + assert m.objective.value is not None + assert np.isclose(m.objective.value, 3, atol=1e-5) + + def test_invalid_reformulate_sos_value(self) -> None: + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + m.add_objective(x.sum(), sense="max") + + with pytest.raises(ValueError, match="Invalid value for reformulate_sos"): + m.solve(solver_name="highs", reformulate_sos="invalid") # type: ignore[arg-type]