Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ jobs:
fail-fast: false
matrix:
python-version:
- "3.10"
- "3.11"
- "3.12"
- "3.13"
- "3.14"
Expand All @@ -32,11 +34,7 @@ jobs:
uses: astral-sh/setup-uv@v1

- name: Install dependencies
run: |
uv sync --dev
uv tool install ruff
uv tool install pyrefly
uv tool install mypy
run: uv sync --dev

- name: Ruff format check
run: uv run --dev ruff format --check --diff .
Expand All @@ -58,7 +56,7 @@ jobs:
# Unset it so pytest can discover pytest-cov and pytest-benchmark via entry points,
# which are required by the addopts configured in pyproject.toml.
unset PYTEST_DISABLE_PLUGIN_AUTOLOAD
uv run --dev pytest -vv --cov=alternative --cov-report=xml --junitxml=test-results.xml
uv run --dev pytest -vv --cov=alternative --cov-report=xml --cov-fail-under=100 --junitxml=test-results.xml

- name: Upload coverage to Codecov
if: (success() || hashFiles('coverage.xml') != '') && env.CODECOV_TOKEN != ''
Expand All @@ -76,4 +74,3 @@ jobs:
files: ./test-results.xml
report_type: test_results
fail_ci_if_error: true

3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ htmlcov/
.coverage.*
.cache
nosetests.xml
test-results.xml
coverage.xml
*.cover
*.py,cover
Expand Down Expand Up @@ -191,4 +192,4 @@ cython_debug/
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore
.cursorindexingignore
11 changes: 9 additions & 2 deletions AGENTS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
The repository defines testing via GitHub actions. When contributing:

* Run the same steps locally that the workflow performs. This includes running `uv run --dev pytest`, `uv run --dev ruff format --check .`, `pyrefly check .`, and `mypy .`.
* Ensure code is formatted with `ruff format` before committing.
* You must check the changes are correct using the same commands as the workflow:
* `uv sync --dev`
* `uv run --dev ruff format --check --diff .`
* `uv run --dev ruff check .`
* `uv run --dev pyrefly check .`
* `uv run --dev mypy .`
* `uv run --dev pytest -vv --cov=alternative --cov-report=xml --cov-fail-under=100 --junitxml=test-results.xml`
* Format code with `uv run --dev ruff format .` before committing.
* Any change to branching paths in `alternative.py` must be followed by a branch coverage run and review for material missing runtime coverage using `uv run --dev pytest --cov=alternative --cov-branch --cov-report=term-missing:skip-covered`.
* Name tests and functions in `snake_case` and give them triple-quoted docstrings similar to the current codebase.
145 changes: 84 additions & 61 deletions alternative.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,17 @@
import inspect
import os
from functools import wraps, lru_cache
from typing import Callable, Protocol
from typing import cast, overload
from typing import (
Any,
Callable,
Final,
Generic,
ParamSpec,
Protocol,
TypeVar,
cast,
overload,
)


DEBUG = os.environ.get("ALTERNATIVE_DEBUG", "0").lower() in (
Expand All @@ -28,22 +37,20 @@
]


class _UNDEFINED: ...
class _Undefined:
"""Sentinel type used when an optional decorator argument is omitted."""


class _SupportsLessThan(Protocol):
def __lt__(self, other: object, /) -> bool: ...


_UNDEFINED_VALUE = _UNDEFINED()
_UNDEFINED_VALUE: Final = _Undefined()

type ImplementationSig[**P, R] = Callable[P, R] | Implementation[P, R]
type AlternativesWrapper[**P, R] = Callable[
[ImplementationSig[P, R]], Alternatives[P, R]
]
type ImplementationWrapper[**P, R] = Callable[
[ImplementationSig[P, R]], Implementation[P, R]
]
P = ParamSpec("P")
R = TypeVar("R")
M = TypeVar("M")
F = TypeVar("F", bound=Callable[..., Any])


class AlternativeError(Exception):
Expand Down Expand Up @@ -101,7 +108,7 @@ def _maybe_get_caller_path() -> str | None:
return None


class Alternatives[**P, R]:
class Alternatives(Generic[P, R]):
def __init__(self, implementation: Callable[P, R], *, default: bool = False):
imp = Implementation(self, implementation, label=_maybe_get_caller_path())
self.reference = imp
Expand All @@ -125,19 +132,25 @@ def __init__(self, implementation: Callable[P, R], *, default: bool = False):

@overload
def add(
self, implementation: _UNDEFINED = _UNDEFINED_VALUE, *, default: bool = False
) -> ImplementationWrapper[P, R]: ...
self, *, default: bool = False
) -> Callable[[Callable[P, R] | Implementation[P, R]], Implementation[P, R]]: ...
@overload
def add(
self, implementation: ImplementationSig[P, R], *, default: bool = False
self,
implementation: Callable[P, R] | Implementation[P, R],
*,
default: bool = False,
) -> Implementation[P, R]: ...

def add(
self,
implementation=_UNDEFINED_VALUE,
implementation: Callable[P, R] | _Undefined = _UNDEFINED_VALUE,
*,
default=False,
) -> Implementation[P, R] | ImplementationWrapper[P, R]:
default: bool = False,
) -> (
Implementation[P, R]
| Callable[[Callable[P, R] | Implementation[P, R]], Implementation[P, R]]
):
if self._implementations_used:
# avoid surprises from implementation changes after selection/inspection
if DEBUG:
Expand All @@ -146,14 +159,14 @@ def add(
msg = None
raise AddTooLateError(msg)

if isinstance(implementation, _UNDEFINED):
if isinstance(implementation, _Undefined):

def wrapper(
implementation: ImplementationSig[P, R],
implementation: Callable[P, R] | Implementation[P, R],
) -> Implementation[P, R]:
return self.add(implementation, default=default)

return cast(ImplementationWrapper[P, R], wrapper)
return wrapper

label = _maybe_get_caller_path()
if not isinstance(implementation, Implementation):
Expand Down Expand Up @@ -210,7 +223,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
# this method will only be called at most once as self.callable overwrites self.__call__
return self.callable(*args, **kwargs)

def measure[M](
def measure(
self, /, operator: Callable[[R], M], *args: P.args, **kwargs: P.kwargs
) -> dict[Implementation[P, R], M]:
"""Invoke each implementation with the given parameters, then evaluate their results with the operator.
Expand Down Expand Up @@ -241,33 +254,32 @@ def measure[M](
@overload
def pytest_parametrize(
self,
test: _UNDEFINED = _UNDEFINED_VALUE,
*,
only_default: bool = False,
): ...
) -> Callable[[F], F]: ...
@overload
def pytest_parametrize(
self,
test: Callable,
test: F,
*,
only_default: bool = False,
): ...
) -> F: ...
def pytest_parametrize(
self,
test=_UNDEFINED_VALUE,
test: F | _Undefined = _UNDEFINED_VALUE,
*,
only_default: bool = False,
):
) -> F | Callable[[F], F]:
"""Decorator to parametrise a test function with implementations - always includes the reference implementation.

:param test: Test function to wrap - this is elided if using the decorator syntax.
:parameter only_default: Only include the reference and default implementations. If False, include all implementations.
"""
import pytest

if isinstance(test, _UNDEFINED):
if isinstance(test, _Undefined):

def decorator(f: Callable):
def decorator(f: F) -> F:
return self.pytest_parametrize(f, only_default=only_default)

return decorator
Expand All @@ -278,38 +290,37 @@ def decorator(f: Callable):

@pytest.mark.parametrize("implementation", implementations)
@wraps(test)
def inner(*args, **kwargs):
def inner(*args: Any, **kwargs: Any) -> Any:
return test(*args, **kwargs)

return inner
return cast(F, inner)

@overload
def pytest_parametrize_pairs(
self,
test: _UNDEFINED = _UNDEFINED_VALUE,
*,
n_cache: int | None = 0,
double_reference: bool = False,
only_default: bool = False,
): ...
) -> Callable[[F], F]: ...
@overload
def pytest_parametrize_pairs(
self,
test: Callable,
test: F,
*,
n_cache: int | None = 0,
double_reference: bool = False,
only_default: bool = False,
): ...
) -> F: ...

def pytest_parametrize_pairs(
self,
test=_UNDEFINED_VALUE,
test: F | _Undefined = _UNDEFINED_VALUE,
*,
n_cache=0,
double_reference=False,
only_default=False,
):
n_cache: int | None = 0,
double_reference: bool = False,
only_default: bool = False,
) -> F | Callable[[F], F]:
"""Decorator to parametrise a test function with the reference and alternative implementations.

:parameter test: Inner pytest function to parameterise with reference and alternative implementations - this is elided if using the decorator syntax.
Expand All @@ -322,9 +333,9 @@ def pytest_parametrize_pairs(
"""
import pytest

if isinstance(test, _UNDEFINED):
if isinstance(test, _Undefined):

def decorator(f: Callable):
def decorator(f: F) -> F:
return self.pytest_parametrize_pairs(
f,
n_cache=n_cache,
Expand All @@ -334,8 +345,9 @@ def decorator(f: Callable):

return decorator

reference_implementation = lru_cache(maxsize=n_cache)(
self.reference.implementation
reference_implementation = cast(
Callable[P, R],
lru_cache(maxsize=n_cache)(self.reference.implementation),
)

implementations = self._select_parametrize_pairs(
Expand All @@ -347,10 +359,10 @@ def decorator(f: Callable):
@pytest.mark.parametrize("reference", [reference_implementation])
@pytest.mark.parametrize("implementation", implementations)
@wraps(test)
def inner(*args, **kwargs):
def inner(*args: Any, **kwargs: Any) -> Any:
return test(*args, **kwargs)

return inner
return cast(F, inner)

def _select_parametrize_implementations(
self, *, only_default: bool
Expand Down Expand Up @@ -387,7 +399,7 @@ def _select_parametrize_pairs(


@dataclasses.dataclass(unsafe_hash=True)
class Implementation[**P, R]:
class Implementation(Generic[P, R]):
alternatives: Alternatives[P, R]
implementation: Callable[P, R]
label: str | None = None
Expand All @@ -410,43 +422,54 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:

@overload
def add(
self, implementation: _UNDEFINED = _UNDEFINED_VALUE, *, default: bool = False
) -> ImplementationWrapper[P, R]: ...
self, *, default: bool = False
) -> Callable[[Callable[P, R] | Implementation[P, R]], Implementation[P, R]]: ...
@overload
def add(
self, implementation: ImplementationSig[P, R], *, default: bool = False
self,
implementation: Callable[P, R] | Implementation[P, R],
*,
default: bool = False,
) -> Implementation[P, R]: ...

def add(
self, implementation=_UNDEFINED_VALUE, *, default=False
) -> Implementation[P, R] | ImplementationWrapper[P, R]:
self,
implementation: Callable[P, R] | _Undefined = _UNDEFINED_VALUE,
*,
default: bool = False,
) -> (
Implementation[P, R]
| Callable[[Callable[P, R] | Implementation[P, R]], Implementation[P, R]]
):
"""Add an alternative implementation."""
if isinstance(implementation, _Undefined):
return self.alternatives.add(default=default)
return self.alternatives.add(implementation, default=default)


@overload
def reference[**P, R](
implementation: _UNDEFINED = _UNDEFINED_VALUE, *, default: bool = False
def reference(
*, default: bool = False
) -> Callable[[Callable[P, R]], Alternatives[P, R]]: ...


@overload
def reference[**P, R]( # pyrefly: ignore[inconsistent-overload]
def reference(
implementation: Callable[P, R], *, default: bool = False
) -> Alternatives[P, R]: ...


def reference[**P, R](
implementation=_UNDEFINED_VALUE,
def reference(
implementation: Callable[P, R] | _Undefined = _UNDEFINED_VALUE,
*,
default=False,
default: bool = False,
) -> Alternatives[P, R] | Callable[[Callable[P, R]], Alternatives[P, R]]:
if isinstance(implementation, _UNDEFINED):
if isinstance(implementation, _Undefined):

def inner(f: Callable[P, R]) -> Alternatives[P, R]:
"""Add the reference implementation to the alternatives"""
return Alternatives(f, default=default)

return cast(AlternativesWrapper[P, R], inner)
return inner
else:
return Alternatives(implementation, default=default)
Loading
Loading