From 608d4cc81e4a6d22536eebc2f2b4f7d59fb73e54 Mon Sep 17 00:00:00 2001 From: Oliver Bristow Date: Sat, 9 May 2026 21:36:29 +0100 Subject: [PATCH 1/8] Add method/classmethod/staticmethod support --- alternative.py | 120 +++++++++++++++++++++++++++++++++----------- docs/pytest.rst | 31 ++++++++++++ docs/quickstart.rst | 42 ++++++++++++++++ test_alternative.py | 104 ++++++++++++++++++++++++++++++++++++++ test_pytest_util.py | 39 ++++++++++++++ 5 files changed, 308 insertions(+), 28 deletions(-) diff --git a/alternative.py b/alternative.py index b571b35..322285a 100644 --- a/alternative.py +++ b/alternative.py @@ -45,6 +45,12 @@ class _SupportsLessThan(Protocol): def __lt__(self, other: object, /) -> bool: ... +class _Descriptor(Protocol): + def __get__( + self, instance: object | None, owner: type[Any] | None = None, / + ) -> Any: ... + + _UNDEFINED_VALUE: Final = _Undefined() P = ParamSpec("P") @@ -108,8 +114,55 @@ def _maybe_get_caller_path() -> str | None: return None +def _bind_implementation( + implementation: Any, + instance: object | None, + owner: type[Any] | None, +) -> Callable[P, R]: + """Bind an implementation using descriptor semantics when available.""" + if owner is None and instance is not None: + owner = type(instance) + + descriptor_get = getattr(implementation, "__get__", None) + if descriptor_get is not None and owner is not None: + return cast(Callable[P, R], descriptor_get(instance, owner)) + return cast(Callable[P, R], implementation) + + +@dataclasses.dataclass(frozen=True) +class _BoundAlternatives(Generic[P, R]): + alternatives: Alternatives[P, R] + instance: object | None + owner: type[Any] | None + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: + implementation: Callable[P, R] = _bind_implementation( + self.alternatives.callable, self.instance, self.owner + ) + return implementation(*args, **kwargs) + + def __getattr__(self, name: str) -> Any: + return getattr(self.alternatives, name) + + +@dataclasses.dataclass(frozen=True) +class _BoundImplementation(Generic[P, R]): + implementation: Implementation[P, R] + instance: object | None + owner: type[Any] | None + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: + implementation: Callable[P, R] = _bind_implementation( + self.implementation.implementation, self.instance, self.owner + ) + return implementation(*args, **kwargs) + + def __getattr__(self, name: str) -> Any: + return getattr(self.implementation, name) + + class Alternatives(Generic[P, R]): - def __init__(self, implementation: Callable[P, R], *, default: bool = False): + def __init__(self, implementation: Any, *, default: bool = False): imp = Implementation(self, implementation, label=_maybe_get_caller_path()) self.reference = imp # tracks the active implementation @@ -120,7 +173,7 @@ def __init__(self, implementation: Callable[P, R], *, default: bool = False): # tracks the use of the set should be self._enumerated = False - self._callable: Callable[P, R] | None = None + self._callable: Any | None = None self._debug_callable_used: str | None = None # beware the order of this depends on the sequence of imports, so may vary between entrypoints @@ -133,24 +186,21 @@ def __init__(self, implementation: Callable[P, R], *, default: bool = False): @overload def add( self, *, default: bool = False - ) -> Callable[[Callable[P, R] | Implementation[P, R]], Implementation[P, R]]: ... + ) -> Callable[[Any], Implementation[P, R]]: ... @overload def add( self, - implementation: Callable[P, R] | Implementation[P, R], + implementation: Any, *, default: bool = False, ) -> Implementation[P, R]: ... def add( self, - implementation: Callable[P, R] | _Undefined = _UNDEFINED_VALUE, + implementation: Any = _UNDEFINED_VALUE, *, default: bool = False, - ) -> ( - Implementation[P, R] - | Callable[[Callable[P, R] | Implementation[P, R]], Implementation[P, R]] - ): + ) -> Implementation[P, R] | Callable[[Any], Implementation[P, R]]: if self._implementations_used: # avoid surprises from implementation changes after selection/inspection if DEBUG: @@ -162,7 +212,7 @@ def add( if isinstance(implementation, _Undefined): def wrapper( - implementation: Callable[P, R] | Implementation[P, R], + implementation: Any, ) -> Implementation[P, R]: return self.add(implementation, default=default) @@ -196,7 +246,7 @@ def wrapper( return imp @property - def callable(self) -> Callable[P, R]: + def callable(self) -> Any: """Return the active implementation. Setting the default implementation is disabled after this is accessed.""" @@ -207,7 +257,6 @@ def callable(self) -> Callable[P, R]: else: self._callable = self.reference self._debug_callable_used = _maybe_get_caller_path() - setattr(self, "__call__", self._callable) # access the list of implementations to freeze them assert self.implementations return self._callable @@ -220,8 +269,13 @@ def implementations(self) -> list[Implementation[P, R]]: return self._implementations def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: - # this method will only be called at most once as self.callable overwrites self.__call__ - return self.callable(*args, **kwargs) + implementation: Callable[P, R] = _bind_implementation(self.callable, None, None) + return implementation(*args, **kwargs) + + def __get__( + self, instance: object | None, owner: type[Any] | None = None + ) -> _BoundAlternatives[P, R]: + return _BoundAlternatives(self, instance, owner) def measure( self, /, operator: Callable[[R], M], *args: P.args, **kwargs: P.kwargs @@ -401,7 +455,7 @@ def _select_parametrize_pairs( @dataclasses.dataclass(unsafe_hash=True) class Implementation(Generic[P, R]): alternatives: Alternatives[P, R] - implementation: Callable[P, R] + implementation: Any label: str | None = None def __post_init__(self): @@ -417,30 +471,34 @@ def __repr__(self) -> str: return f"Implementation({implementation_name})" def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: - setattr(self, "__call__", self.implementation) - return self.__call__(*args, **kwargs) + implementation: Callable[P, R] = _bind_implementation( + self.implementation, None, None + ) + return implementation(*args, **kwargs) + + def __get__( + self, instance: object | None, owner: type[Any] | None = None + ) -> _BoundImplementation[P, R]: + return _BoundImplementation(self, instance, owner) @overload def add( self, *, default: bool = False - ) -> Callable[[Callable[P, R] | Implementation[P, R]], Implementation[P, R]]: ... + ) -> Callable[[Any], Implementation[P, R]]: ... @overload def add( self, - implementation: Callable[P, R] | Implementation[P, R], + implementation: Any, *, default: bool = False, ) -> Implementation[P, R]: ... def add( self, - implementation: Callable[P, R] | _Undefined = _UNDEFINED_VALUE, + implementation: Any = _UNDEFINED_VALUE, *, default: bool = False, - ) -> ( - Implementation[P, R] - | Callable[[Callable[P, R] | Implementation[P, R]], Implementation[P, R]] - ): + ) -> Implementation[P, R] | Callable[[Any], Implementation[P, R]]: """Add an alternative implementation.""" if isinstance(implementation, _Undefined): return self.alternatives.add(default=default) @@ -455,18 +513,24 @@ def reference( @overload def reference( - implementation: Callable[P, R], *, default: bool = False + implementation: Callable[P, R] | _Descriptor, *, default: bool = False ) -> Alternatives[P, R]: ... +@overload +def reference( + implementation: Any, *, default: bool = False +) -> Alternatives[Any, Any]: ... + + def reference( - implementation: Callable[P, R] | _Undefined = _UNDEFINED_VALUE, + implementation: Any = _UNDEFINED_VALUE, *, default: bool = False, -) -> Alternatives[P, R] | Callable[[Callable[P, R]], Alternatives[P, R]]: +) -> Alternatives[Any, Any] | Callable[[Any], Alternatives[Any, Any]]: if isinstance(implementation, _Undefined): - def inner(f: Callable[P, R]) -> Alternatives[P, R]: + def inner(f: Any) -> Alternatives[Any, Any]: """Add the reference implementation to the alternatives""" return Alternatives(f, default=default) diff --git a/docs/pytest.rst b/docs/pytest.rst index e57e9a4..75af88c 100644 --- a/docs/pytest.rst +++ b/docs/pytest.rst @@ -85,6 +85,37 @@ Use :meth:`alternative.Alternatives.pytest_parametrize` with Pytest generates readable parameter names from the underlying function names. +Testing Methods +--------------- + +Method alternatives can be tested with the same helpers by passing an explicit +instance to the parametrized implementation: + +.. code-block:: python + + class Counter: + def __init__(self, value: int): + self.value = value + + @alternative.reference + def total(self) -> int: + return int(str(self.value)) + + @total.add(default=True) + def total_fast(self) -> int: + return self.value + + + @Counter.total.pytest_parametrize_pairs() + def test_totals_are_equivalent(reference, implementation): + """Every method implementation returns the same total.""" + counter = Counter(3) + assert implementation(counter) == reference(counter) + +The pytest helpers parametrize callables. They do not change the selected +default implementation for each test parameter; the library still keeps the +active implementation stable once it has been used. + Collection Order ---------------- diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 8e61096..6623d62 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -96,3 +96,45 @@ alternatives set. This makes chained registration convenient: return int(text) The three implementations still belong to the same alternatives set. + +Use Methods +----------- + +``alternative`` follows Python descriptor binding rules, so the same decorator +also works on methods. Put ``@alternative.reference`` outside +``@classmethod`` or ``@staticmethod`` when those decorators are needed: + +.. code-block:: python + + class Parser: + @alternative.reference + def parse(self, value: str) -> int: + return int(value.strip()) + + @parse.add(default=True) + def parse_fast(self, value: str) -> int: + return int(value) + + @alternative.reference + @classmethod + def from_text(cls, value: str) -> "Parser": + return cls(value.strip()) + + @from_text.add(default=True) + @classmethod + def from_text_fast(cls, value: str) -> "Parser": + return cls(value) + + @alternative.reference + @staticmethod + def is_valid(value: str) -> bool: + return value.strip().isdigit() + + @is_valid.add(default=True) + @staticmethod + def is_valid_fast(value: str) -> bool: + return value.isdigit() + +Calling through an instance or class binds ``self`` and ``cls`` normally. Direct +alternative implementations also bind normally, so ``parser.parse_fast("1")`` +or ``Parser.from_text_fast("1")`` call that implementation directly. diff --git a/test_alternative.py b/test_alternative.py index adb096b..5141643 100644 --- a/test_alternative.py +++ b/test_alternative.py @@ -341,3 +341,107 @@ def alt(): repr(alt) == "Implementation(test_implementation_repr_without_label..alt)" ) + + +def test_instance_method_binding(): + """Alternatives bind instance methods through descriptor access.""" + + class Calculator: + def __init__(self, offset: int): + self.offset = offset + + @alternative.reference + def add(self, value: int) -> tuple[str, int]: + return ("reference", self.offset + value) + + @add.add(default=True) + def add_default(self, value: int) -> tuple[str, int]: + return ("default", self.offset + value) + + @add.add + def add_extra(self, value: int) -> tuple[str, int]: + return ("extra", self.offset + value) + + calculator = Calculator(10) + + assert calculator.add(5) == ("default", 15) + assert calculator.add_default(5) == ("default", 15) + assert calculator.add_extra(5) == ("extra", 15) + assert Calculator.add(calculator, 5) == ("default", 15) + assert Calculator.add_extra(calculator, 5) == ("extra", 15) + assert calculator.add_extra.alternatives is Calculator.__dict__["add"] + assert Calculator.__dict__["add"].__get__(calculator)(5) == ("default", 15) + + +def test_classmethod_binding(): + """Alternatives bind classmethod implementations to the owner class.""" + + class Factory: + marker = "Factory" + + @alternative.reference + @classmethod + def build(cls, value: str) -> tuple[str, str, str]: + return ("reference", cls.marker, value) + + @build.add(default=True) + @classmethod + def build_default(cls, value: str) -> tuple[str, str, str]: + return ("default", cls.marker, value) + + @build.add + @classmethod + def build_extra(cls, value: str) -> tuple[str, str, str]: + return ("extra", cls.marker, value) + + class ChildFactory(Factory): + marker = "ChildFactory" + + assert Factory.build("a") == ("default", "Factory", "a") + assert Factory().build("a") == ("default", "Factory", "a") + assert Factory.build_default("a") == ("default", "Factory", "a") + assert Factory.build_extra("a") == ("extra", "Factory", "a") + assert ChildFactory.build("a") == ("default", "ChildFactory", "a") + assert ChildFactory.build_extra("a") == ("extra", "ChildFactory", "a") + + +def test_staticmethod_binding(): + """Alternatives preserve staticmethod binding from class and instance access.""" + + class Parser: + @alternative.reference + @staticmethod + def parse(value: str) -> tuple[str, str]: + return ("reference", value) + + @parse.add(default=True) + @staticmethod + def parse_default(value: str) -> tuple[str, str]: + return ("default", value) + + @parse.add + @staticmethod + def parse_extra(value: str) -> tuple[str, str]: + return ("extra", value) + + assert Parser.parse("x") == ("default", "x") + assert Parser().parse("x") == ("default", "x") + assert Parser.parse_default("x") == ("default", "x") + assert Parser.parse_extra("x") == ("extra", "x") + + +def test_bound_attribute_access_does_not_freeze_implementations(): + """Accessing a method alternative does not freeze registrations before invocation.""" + + class Counter: + @alternative.reference + def value(self) -> int: + return 1 + + bound_value = Counter().value + + @Counter.value.add(default=True) + def value_default(self) -> int: + return 2 + + assert bound_value() == 2 diff --git a/test_pytest_util.py b/test_pytest_util.py index ca51de2..ed55596 100644 --- a/test_pytest_util.py +++ b/test_pytest_util.py @@ -65,6 +65,45 @@ def extra_impl(): ] +def test_select_parametrize_implementations_with_implicit_default(): + """Only-default parametrization includes the wrapper when the reference default is implicit.""" + + @alternative.reference + def reference_impl(): + return 1 + + @reference_impl.add + def extra_impl(): + return 2 + + selected = reference_impl._select_parametrize_implementations( # pyrefly: ignore + only_default=True + ) + + assert selected == [ + reference_impl.reference.implementation, + reference_impl.callable, + ] + + +def test_select_parametrize_implementations_with_explicit_reference_default(): + """Only-default parametrization does not duplicate an explicitly defaulted reference.""" + + @alternative.reference(default=True) + def reference_impl(): + return 1 + + @reference_impl.add + def extra_impl(): + return 2 + + selected = reference_impl._select_parametrize_implementations( # pyrefly: ignore + only_default=True + ) + + assert selected == [reference_impl.reference.implementation] + + @pytest.mark.parametrize("only_default", [False, True]) @pytest.mark.parametrize("double_reference", [False, True]) def test_pytest_parametrize_pairs_signature_and_parameters( From b1c6470e37ac84a9d81e6f69f465833aa30a7a48 Mon Sep 17 00:00:00 2001 From: Oliver Bristow Date: Sat, 9 May 2026 22:07:14 +0100 Subject: [PATCH 2/8] Fix the typing --- AGENTS.md | 1 + alternative.py | 461 +++++++++++++++++++++++++++++++++++--------- test_alternative.py | 49 ++++- test_pytest_util.py | 28 ++- 4 files changed, 434 insertions(+), 105 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 382899b..45b4bde 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -10,5 +10,6 @@ The repository defines testing via GitHub actions. When contributing: * `uv run --group=docs sphinx-build --fail-on-warning --keep-going --builder=html docs /tmp/alternative-docs-html` * Format code with `uv run --dev ruff format .` before committing. * Keep the documentation in `docs/` up to date with user-facing behavior, API, and workflow changes. Documentation must compile without warnings. +* Keep `alternative.py` strictly typed: do not use `typing.Any` or `Any`, and do not add mypy or pyrefly suppression comments. Fix the annotations so public decorators remain transparent to type checkers and IDEs. * Any change to branching paths in `alternative.py` must be followed by a branch coverage run and review for material missing runtime coverage using `uv run --dev pytest --cov=alternative --cov-branch --cov-report=term-missing:skip-covered`. * Name tests and functions in `snake_case` and give them triple-quoted docstrings similar to the current codebase. diff --git a/alternative.py b/alternative.py index 322285a..459dd40 100644 --- a/alternative.py +++ b/alternative.py @@ -5,8 +5,8 @@ import os from functools import wraps, lru_cache from typing import ( - Any, Callable, + Concatenate, Final, Generic, ParamSpec, @@ -45,18 +45,52 @@ class _SupportsLessThan(Protocol): def __lt__(self, other: object, /) -> bool: ... -class _Descriptor(Protocol): - def __get__( - self, instance: object | None, owner: type[Any] | None = None, / - ) -> Any: ... - - _UNDEFINED_VALUE: Final = _Undefined() P = ParamSpec("P") R = TypeVar("R") +R_co = TypeVar("R_co", covariant=True) M = TypeVar("M") -F = TypeVar("F", bound=Callable[..., Any]) +S = TypeVar("S") +Owner = TypeVar("Owner") +F = TypeVar("F", bound=Callable[..., object]) + + +class _NoReceiver: + """Marker type for descriptors that bind themselves before user calls.""" + + +class _TypedDescriptor(Protocol[P, R_co]): + def __get__( + self, instance: object | None, owner: type[object] | None = None, / + ) -> Callable[P, R_co]: ... + + +class _ReferenceDecorator(Protocol): + @overload + def __call__( + self, implementation: classmethod[Owner, P, R], / + ) -> Alternatives[_NoReceiver, P, R]: ... + + @overload + def __call__( + self, implementation: staticmethod[P, R], / + ) -> Alternatives[_NoReceiver, P, R]: ... + + @overload + def __call__( + self, implementation: Callable[Concatenate[type[Owner], P], R], / + ) -> Alternatives[_NoReceiver, P, R]: ... + + @overload + def __call__( + self, implementation: Callable[Concatenate[S, P], R], / + ) -> Alternatives[S, P, R]: ... + + @overload + def __call__( + self, implementation: Callable[P, R], / + ) -> Alternatives[_NoReceiver, P, R]: ... class AlternativeError(Exception): @@ -115,69 +149,155 @@ def _maybe_get_caller_path() -> str | None: def _bind_implementation( - implementation: Any, + implementation: Callable[..., R] | _TypedDescriptor[P, R], instance: object | None, - owner: type[Any] | None, -) -> Callable[P, R]: + owner: type[object] | None, +) -> Callable[..., R]: """Bind an implementation using descriptor semantics when available.""" if owner is None and instance is not None: owner = type(instance) - descriptor_get = getattr(implementation, "__get__", None) - if descriptor_get is not None and owner is not None: - return cast(Callable[P, R], descriptor_get(instance, owner)) - return cast(Callable[P, R], implementation) + if hasattr(implementation, "__get__") and owner is not None: + return cast(_TypedDescriptor[P, R], implementation).__get__(instance, owner) + return cast(Callable[..., R], implementation) @dataclasses.dataclass(frozen=True) -class _BoundAlternatives(Generic[P, R]): - alternatives: Alternatives[P, R] +class _BoundAlternatives(Generic[S, P, R]): + alternatives: Alternatives[S, P, R] instance: object | None - owner: type[Any] | None + owner: type[object] | None def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: - implementation: Callable[P, R] = _bind_implementation( + implementation: Callable[..., R] = _bind_implementation( self.alternatives.callable, self.instance, self.owner ) return implementation(*args, **kwargs) - def __getattr__(self, name: str) -> Any: - return getattr(self.alternatives, name) + @overload + def add( + self: _BoundAlternatives[_NoReceiver, P, R], + implementation: Callable[P, R] + | Callable[Concatenate[type[Owner], P], R] + | classmethod[Owner, P, R] + | staticmethod[P, R], + *, + default: bool = False, + ) -> Implementation[_NoReceiver, P, R]: ... + + @overload + def add( + self, + implementation: Callable[Concatenate[S, P], R] | Implementation[S, P, R], + *, + default: bool = False, + ) -> Implementation[S, P, R]: ... + + @overload + def add( + self, *, default: bool = False + ) -> Callable[ + [Callable[..., R] | classmethod[Owner, P, R] | staticmethod[P, R]], + Implementation[S, P, R], + ]: ... + + def add( + self, + implementation: object = _UNDEFINED_VALUE, + *, + default: bool = False, + ) -> object: + add = cast(Callable[..., object], self.alternatives.add) + return add(implementation, default=default) + + @property + def implementations(self) -> list[Implementation[S, P, R]]: + return self.alternatives.implementations + + @property + def reference(self) -> Implementation[S, P, R]: + return self.alternatives.reference + + @property + def callable(self) -> Callable[..., R] | _TypedDescriptor[P, R]: + return self.alternatives.callable @dataclasses.dataclass(frozen=True) -class _BoundImplementation(Generic[P, R]): - implementation: Implementation[P, R] +class _BoundImplementation(Generic[S, P, R]): + implementation: Implementation[S, P, R] instance: object | None - owner: type[Any] | None + owner: type[object] | None def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: - implementation: Callable[P, R] = _bind_implementation( + implementation: Callable[..., R] = _bind_implementation( self.implementation.implementation, self.instance, self.owner ) return implementation(*args, **kwargs) - def __getattr__(self, name: str) -> Any: - return getattr(self.implementation, name) + @property + def alternatives(self) -> Alternatives[S, P, R]: + return self.implementation.alternatives + @overload + def add( + self: _BoundImplementation[_NoReceiver, P, R], + implementation: Callable[P, R] + | Callable[Concatenate[type[Owner], P], R] + | classmethod[Owner, P, R] + | staticmethod[P, R], + *, + default: bool = False, + ) -> Implementation[_NoReceiver, P, R]: ... -class Alternatives(Generic[P, R]): - def __init__(self, implementation: Any, *, default: bool = False): + @overload + def add( + self, + implementation: Callable[Concatenate[S, P], R] | Implementation[S, P, R], + *, + default: bool = False, + ) -> Implementation[S, P, R]: ... + + @overload + def add( + self, *, default: bool = False + ) -> Callable[ + [Callable[..., R] | classmethod[Owner, P, R] | staticmethod[P, R]], + Implementation[S, P, R], + ]: ... + + def add( + self, + implementation: object = _UNDEFINED_VALUE, + *, + default: bool = False, + ) -> object: + add = cast(Callable[..., object], self.implementation.add) + return add(implementation, default=default) + + +class Alternatives(Generic[S, P, R]): + def __init__( + self, + implementation: Callable[..., R] | _TypedDescriptor[P, R], + *, + default: bool = False, + ): imp = Implementation(self, implementation, label=_maybe_get_caller_path()) self.reference = imp # tracks the active implementation - self._default: Implementation[P, R] | None = None + self._default: Implementation[S, P, R] | None = None self._debug_default: str | None = None self._invoked = False self._debug_invoked_site: str | None = None # tracks the use of the set should be self._enumerated = False - self._callable: Any | None = None + self._callable: Callable[..., R] | _TypedDescriptor[P, R] | None = None self._debug_callable_used: str | None = None # beware the order of this depends on the sequence of imports, so may vary between entrypoints - self._implementations: list[Implementation[P, R]] = [] + self._implementations: list[Implementation[S, P, R]] = [] self._implementations_used: bool = False """indicates if the list of implementations has been used though the external API""" self._debug_implementations_used: str | None = None @@ -185,22 +305,37 @@ def __init__(self, implementation: Any, *, default: bool = False): @overload def add( - self, *, default: bool = False - ) -> Callable[[Any], Implementation[P, R]]: ... + self: Alternatives[_NoReceiver, P, R], + implementation: Callable[P, R] + | Callable[Concatenate[type[Owner], P], R] + | classmethod[Owner, P, R] + | staticmethod[P, R], + *, + default: bool = False, + ) -> Implementation[_NoReceiver, P, R]: ... + @overload def add( self, - implementation: Any, + implementation: Callable[Concatenate[S, P], R] | Implementation[S, P, R], *, default: bool = False, - ) -> Implementation[P, R]: ... + ) -> Implementation[S, P, R]: ... + + @overload + def add( + self, *, default: bool = False + ) -> Callable[ + [Callable[..., R] | classmethod[Owner, P, R] | staticmethod[P, R]], + Implementation[S, P, R], + ]: ... def add( self, - implementation: Any = _UNDEFINED_VALUE, + implementation: object = _UNDEFINED_VALUE, *, default: bool = False, - ) -> Implementation[P, R] | Callable[[Any], Implementation[P, R]]: + ) -> object: if self._implementations_used: # avoid surprises from implementation changes after selection/inspection if DEBUG: @@ -212,15 +347,20 @@ def add( if isinstance(implementation, _Undefined): def wrapper( - implementation: Any, - ) -> Implementation[P, R]: - return self.add(implementation, default=default) + implementation: Callable[..., R] | _TypedDescriptor[P, R], + ) -> Implementation[S, P, R]: + add = cast(Callable[..., Implementation[S, P, R]], self.add) + return add(implementation, default=default) return wrapper label = _maybe_get_caller_path() if not isinstance(implementation, Implementation): - imp = Implementation(self, implementation, label=label) + imp = Implementation( + self, + cast(Callable[..., R] | _TypedDescriptor[P, R], implementation), + label=label, + ) elif implementation.alternatives is not self: raise CrossAlternativesImplementationError( f"Cannot add {implementation!r} to {self.reference!r}; " @@ -246,7 +386,7 @@ def wrapper( return imp @property - def callable(self) -> Any: + def callable(self) -> Callable[..., R] | _TypedDescriptor[P, R]: """Return the active implementation. Setting the default implementation is disabled after this is accessed.""" @@ -262,24 +402,80 @@ def callable(self) -> Any: return self._callable @property - def implementations(self) -> list[Implementation[P, R]]: + def implementations(self) -> list[Implementation[S, P, R]]: if not self._implementations_used: self._implementations_used = True self._debug_implementations_used = _maybe_get_caller_path() return self._implementations - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: - implementation: Callable[P, R] = _bind_implementation(self.callable, None, None) - return implementation(*args, **kwargs) + @overload + def __call__( + self: Alternatives[_NoReceiver, P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: ... + + @overload + def __call__(self, receiver: S, *args: P.args, **kwargs: P.kwargs) -> R: ... + + def __call__( + self, + receiver: object = _UNDEFINED_VALUE, + *args: object, + **kwargs: object, + ) -> R: + implementation: Callable[..., R] = _bind_implementation( + self.callable, None, None + ) + if isinstance(receiver, _Undefined): + return implementation(**kwargs) + return implementation(receiver, *args, **kwargs) + + @overload + def __get__( + self, instance: None, owner: type[object] | None = None + ) -> Alternatives[S, P, R]: ... + + @overload + def __get__( + self: Alternatives[_NoReceiver, P, R], + instance: object, + owner: type[object] | None = None, + ) -> _BoundAlternatives[_NoReceiver, P, R]: ... + @overload def __get__( - self, instance: object | None, owner: type[Any] | None = None - ) -> _BoundAlternatives[P, R]: + self, instance: S, owner: type[object] | None = None + ) -> _BoundAlternatives[S, P, R]: ... + + @overload + def __get__( + self, instance: object, owner: type[object] | None = None + ) -> Callable[Concatenate[S, P], R]: ... + + def __get__( + self, instance: object | None, owner: type[object] | None = None + ) -> object: + if instance is None and not self._binds_on_class_access(): + return self return _BoundAlternatives(self, instance, owner) + def _binds_on_class_access(self) -> bool: + """Return True when class access needs descriptor binding before calling.""" + implementation = ( + self._default.implementation + if self._default is not None + else self.reference.implementation + ) + return isinstance(implementation, classmethod) + def measure( - self, /, operator: Callable[[R], M], *args: P.args, **kwargs: P.kwargs - ) -> dict[Implementation[P, R], M]: + self, + /, + operator: Callable[[R], M], + *args: object, + **kwargs: object, + ) -> dict[Implementation[S, P, R], M]: """Invoke each implementation with the given parameters, then evaluate their results with the operator. This is useful when comparing implementations that have different results, which can be compared by some cost. @@ -289,7 +485,8 @@ def measure( __lt__(a,b) is called); otherwise they are returned in the order of the implementations. """ result = { - i: operator(i.implementation(*args, **kwargs)) for i in self.implementations + i: operator(cast(Callable[..., R], i)(*args, **kwargs)) + for i in self.implementations } try: # try to sort the dictionary by the measurements @@ -297,7 +494,9 @@ def measure( sorted( result.items(), key=cast( - Callable[[tuple[Implementation[P, R], M]], _SupportsLessThan], + Callable[ + [tuple[Implementation[S, P, R], M]], _SupportsLessThan + ], lambda x: cast(_SupportsLessThan, x[1]), ), ) @@ -344,7 +543,7 @@ def decorator(f: F) -> F: @pytest.mark.parametrize("implementation", implementations) @wraps(test) - def inner(*args: Any, **kwargs: Any) -> Any: + def inner(*args: object, **kwargs: object) -> object: return test(*args, **kwargs) return cast(F, inner) @@ -400,8 +599,10 @@ def decorator(f: F) -> F: return decorator reference_implementation = cast( - Callable[P, R], - lru_cache(maxsize=n_cache)(self.reference.implementation), + Callable[..., R], + lru_cache(maxsize=n_cache)( + cast(Callable[..., R], self.reference.implementation) + ), ) implementations = self._select_parametrize_pairs( @@ -413,19 +614,21 @@ def decorator(f: F) -> F: @pytest.mark.parametrize("reference", [reference_implementation]) @pytest.mark.parametrize("implementation", implementations) @wraps(test) - def inner(*args: Any, **kwargs: Any) -> Any: + def inner(*args: object, **kwargs: object) -> object: return test(*args, **kwargs) return cast(F, inner) def _select_parametrize_implementations( self, *, only_default: bool - ) -> list[Callable[P, R]]: + ) -> list[Callable[..., R] | _TypedDescriptor[P, R]]: """Return implementation callables used for ``pytest_parametrize``.""" if only_default: reference_implementation = self.reference.implementation default_implementation = self.callable - implementations = [reference_implementation] + implementations: list[Callable[..., R] | _TypedDescriptor[P, R]] = [ + reference_implementation + ] if default_implementation is not reference_implementation: implementations.append(default_implementation) return implementations @@ -434,12 +637,13 @@ def _select_parametrize_implementations( def _select_parametrize_pairs( self, *, - reference_implementation: Callable[P, R], + reference_implementation: Callable[..., R], only_default: bool, double_reference: bool, - ) -> list[Callable[P, R]]: + ) -> list[Callable[..., R] | _TypedDescriptor[P, R]]: """Return implementation callables used for ``pytest_parametrize_pairs``.""" # use underlying functions so pytest can generate readable IDs. + implementations: list[Callable[..., R] | _TypedDescriptor[P, R]] if only_default: implementations = [self.callable] if double_reference and self.callable is not self.reference.implementation: @@ -453,9 +657,9 @@ def _select_parametrize_pairs( @dataclasses.dataclass(unsafe_hash=True) -class Implementation(Generic[P, R]): - alternatives: Alternatives[P, R] - implementation: Any +class Implementation(Generic[S, P, R]): + alternatives: Alternatives[S, P, R] + implementation: Callable[..., R] | _TypedDescriptor[P, R] label: str | None = None def __post_init__(self): @@ -470,70 +674,151 @@ def __repr__(self) -> str: return f"Implementation({implementation_name}, label={self.label!r})" return f"Implementation({implementation_name})" - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: - implementation: Callable[P, R] = _bind_implementation( + @overload + def __call__( + self: Implementation[_NoReceiver, P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: ... + + @overload + def __call__(self, receiver: S, *args: P.args, **kwargs: P.kwargs) -> R: ... + + def __call__( + self, + receiver: object = _UNDEFINED_VALUE, + *args: object, + **kwargs: object, + ) -> R: + implementation: Callable[..., R] = _bind_implementation( self.implementation, None, None ) - return implementation(*args, **kwargs) + if isinstance(receiver, _Undefined): + return implementation(**kwargs) + return implementation(receiver, *args, **kwargs) + + @overload + def __get__( + self, instance: None, owner: type[object] | None = None + ) -> Implementation[S, P, R]: ... + + @overload + def __get__( + self: Implementation[_NoReceiver, P, R], + instance: object, + owner: type[object] | None = None, + ) -> _BoundImplementation[_NoReceiver, P, R]: ... + + @overload + def __get__( + self, instance: S, owner: type[object] | None = None + ) -> _BoundImplementation[S, P, R]: ... + + @overload + def __get__( + self, instance: object, owner: type[object] | None = None + ) -> Callable[Concatenate[S, P], R]: ... def __get__( - self, instance: object | None, owner: type[Any] | None = None - ) -> _BoundImplementation[P, R]: + self, instance: object | None, owner: type[object] | None = None + ) -> object: + if instance is None and not isinstance(self.implementation, classmethod): + return self return _BoundImplementation(self, instance, owner) @overload def add( - self, *, default: bool = False - ) -> Callable[[Any], Implementation[P, R]]: ... + self: Implementation[_NoReceiver, P, R], + implementation: Callable[P, R] + | Callable[Concatenate[type[Owner], P], R] + | classmethod[Owner, P, R] + | staticmethod[P, R], + *, + default: bool = False, + ) -> Implementation[_NoReceiver, P, R]: ... + @overload def add( self, - implementation: Any, + implementation: Callable[Concatenate[S, P], R] | Implementation[S, P, R], *, default: bool = False, - ) -> Implementation[P, R]: ... + ) -> Implementation[S, P, R]: ... + + @overload + def add( + self, *, default: bool = False + ) -> Callable[ + [Callable[..., R] | classmethod[Owner, P, R] | staticmethod[P, R]], + Implementation[S, P, R], + ]: ... def add( self, - implementation: Any = _UNDEFINED_VALUE, + implementation: object = _UNDEFINED_VALUE, *, default: bool = False, - ) -> Implementation[P, R] | Callable[[Any], Implementation[P, R]]: + ) -> object: """Add an alternative implementation.""" if isinstance(implementation, _Undefined): return self.alternatives.add(default=default) - return self.alternatives.add(implementation, default=default) + add = cast(Callable[..., object], self.alternatives.add) + return add(implementation, default=default) + + +@overload +def reference(*, default: bool = False) -> _ReferenceDecorator: ... @overload def reference( - *, default: bool = False -) -> Callable[[Callable[P, R]], Alternatives[P, R]]: ... + implementation: classmethod[Owner, P, R], *, default: bool = False +) -> Alternatives[_NoReceiver, P, R]: ... @overload def reference( - implementation: Callable[P, R] | _Descriptor, *, default: bool = False -) -> Alternatives[P, R]: ... + implementation: staticmethod[P, R], *, default: bool = False +) -> Alternatives[_NoReceiver, P, R]: ... @overload def reference( - implementation: Any, *, default: bool = False -) -> Alternatives[Any, Any]: ... + implementation: Callable[Concatenate[type[Owner], P], R], + *, + default: bool = False, +) -> Alternatives[_NoReceiver, P, R]: ... +@overload def reference( - implementation: Any = _UNDEFINED_VALUE, + implementation: Callable[Concatenate[S, P], R], *, default: bool = False +) -> Alternatives[S, P, R]: ... + + +@overload +def reference( + implementation: Callable[P, R], *, default: bool = False +) -> Alternatives[_NoReceiver, P, R]: ... + + +def reference( + implementation: object = _UNDEFINED_VALUE, *, default: bool = False, -) -> Alternatives[Any, Any] | Callable[[Any], Alternatives[Any, Any]]: +) -> object: if isinstance(implementation, _Undefined): - def inner(f: Any) -> Alternatives[Any, Any]: + def inner(f: object) -> object: """Add the reference implementation to the alternatives""" - return Alternatives(f, default=default) + return Alternatives( + cast(Callable[..., object] | _TypedDescriptor[..., object], f), + default=default, + ) - return inner + return cast(_ReferenceDecorator, inner) else: - return Alternatives(implementation, default=default) + return Alternatives( + cast(Callable[..., object] | _TypedDescriptor[..., object], implementation), + default=default, + ) diff --git a/test_alternative.py b/test_alternative.py index 5141643..5f15432 100644 --- a/test_alternative.py +++ b/test_alternative.py @@ -1,10 +1,11 @@ import re +from inspect import signature +from typing import Callable, cast from unittest.mock import MagicMock import pytest import alternative -from inspect import signature def imp_for_cmp(imp: alternative.Implementation | None) -> dict | None: @@ -39,9 +40,9 @@ def alt3(): def test_coupled_signatures(): """The signatures of reference, Alternative.add, and Implementation.add are aligned.""" - ref_sig = signature(alternative.reference) # pyrefly: ignore - alt_sig = signature(alternative.Alternatives.add) # pyrefly: ignore - imp_sig = signature(alternative.Implementation.add) # pyrefly: ignore + ref_sig = signature(cast(Callable[..., object], alternative.reference)) + alt_sig = signature(cast(Callable[..., object], alternative.Alternatives.add)) + imp_sig = signature(cast(Callable[..., object], alternative.Implementation.add)) assert alt_sig.parameters == imp_sig.parameters # skip the self-parameter to give matching signatures assert ( @@ -343,7 +344,7 @@ def alt(): ) -def test_instance_method_binding(): +def test_instance_method_binding() -> None: """Alternatives bind instance methods through descriptor access.""" class Calculator: @@ -370,10 +371,14 @@ def add_extra(self, value: int) -> tuple[str, int]: assert Calculator.add(calculator, 5) == ("default", 15) assert Calculator.add_extra(calculator, 5) == ("extra", 15) assert calculator.add_extra.alternatives is Calculator.__dict__["add"] - assert Calculator.__dict__["add"].__get__(calculator)(5) == ("default", 15) + descriptor = cast( + alternative.Alternatives[Calculator, [int], tuple[str, int]], + Calculator.__dict__["add"], + ) + assert descriptor.__get__(calculator)(5) == ("default", 15) -def test_classmethod_binding(): +def test_classmethod_binding() -> None: """Alternatives bind classmethod implementations to the owner class.""" class Factory: @@ -405,7 +410,7 @@ class ChildFactory(Factory): assert ChildFactory.build_extra("a") == ("extra", "ChildFactory", "a") -def test_staticmethod_binding(): +def test_staticmethod_binding() -> None: """Alternatives preserve staticmethod binding from class and instance access.""" class Parser: @@ -430,7 +435,7 @@ def parse_extra(value: str) -> tuple[str, str]: assert Parser.parse_extra("x") == ("extra", "x") -def test_bound_attribute_access_does_not_freeze_implementations(): +def test_bound_attribute_access_does_not_freeze_implementations() -> None: """Accessing a method alternative does not freeze registrations before invocation.""" class Counter: @@ -445,3 +450,29 @@ def value_default(self) -> int: return 2 assert bound_value() == 2 + + +def test_bound_method_registration_delegates_to_alternatives() -> None: + """Bound alternatives expose registration and metadata without dynamic attribute typing.""" + + class Counter: + @alternative.reference + def value(self) -> int: + return 1 + + counter = Counter() + + @counter.value.add(default=True) + def value_default(self) -> int: + return 2 + + @value_default.__get__(counter, type(counter)).add + def value_extra(self) -> int: + return 3 + + assert counter.value.implementations == Counter.__dict__["value"].implementations + assert counter.value.reference is Counter.__dict__["value"].reference + assert counter.value.callable is Counter.__dict__["value"].callable + assert counter.value() == 2 + assert value_default.__get__(counter, type(counter))() == 2 + assert value_extra(counter) == 3 diff --git a/test_pytest_util.py b/test_pytest_util.py index ed55596..a9c310c 100644 --- a/test_pytest_util.py +++ b/test_pytest_util.py @@ -1,4 +1,5 @@ import inspect +from collections.abc import Callable import alternative import pytest @@ -50,9 +51,14 @@ def default_impl(): def extra_impl(): return 3 - selected = reference_impl._select_parametrize_implementations( # pyrefly: ignore - only_default=only_default + def parametrized(implementation: Callable[[], int]) -> None: + """Placeholder test used to inspect pytest parametrization values.""" + assert implementation() in {1, 2, 3} + + decorated = reference_impl.pytest_parametrize( + parametrized, only_default=only_default ) + selected = _parametrize_values(decorated, "implementation")[0] default_callable = default_impl.implementation extra_callable = extra_impl.implementation if only_default: @@ -76,9 +82,12 @@ def reference_impl(): def extra_impl(): return 2 - selected = reference_impl._select_parametrize_implementations( # pyrefly: ignore - only_default=True - ) + def parametrized(implementation: Callable[[], int]) -> None: + """Placeholder test used to inspect pytest parametrization values.""" + assert implementation() in {1, 2} + + decorated = reference_impl.pytest_parametrize(parametrized, only_default=True) + selected = _parametrize_values(decorated, "implementation")[0] assert selected == [ reference_impl.reference.implementation, @@ -97,9 +106,12 @@ def reference_impl(): def extra_impl(): return 2 - selected = reference_impl._select_parametrize_implementations( # pyrefly: ignore - only_default=True - ) + def parametrized(implementation: Callable[[], int]) -> None: + """Placeholder test used to inspect pytest parametrization values.""" + assert implementation() in {1, 2} + + decorated = reference_impl.pytest_parametrize(parametrized, only_default=True) + selected = _parametrize_values(decorated, "implementation")[0] assert selected == [reference_impl.reference.implementation] From 488a733aa0d958cf58b8bacab18daa03debb2a42 Mon Sep 17 00:00:00 2001 From: Oliver Bristow Date: Sat, 9 May 2026 22:13:10 +0100 Subject: [PATCH 3/8] add revel-type tests for decorator transparency --- pyproject.toml | 1 + typing_tests/type_probes.py | 83 +++++++++++++++++++++++++++++++++++++ uv.lock | 2 + 3 files changed, 86 insertions(+) create mode 100644 typing_tests/type_probes.py diff --git a/pyproject.toml b/pyproject.toml index f3a981b..02f2d74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dev-dependencies = [ "mypy>=2.0.0", "pyrefly>=0.64.1", "ruff>=0.15.12", + "typing-extensions>=4.15.0", ] [dependency-groups] diff --git a/typing_tests/type_probes.py b/typing_tests/type_probes.py new file mode 100644 index 0000000..df4e50e --- /dev/null +++ b/typing_tests/type_probes.py @@ -0,0 +1,83 @@ +"""Type-level probes for decorator transparency. + +This module is intentionally not a pytest test. It is exercised by the CI +type-checking stages, which run mypy and pyrefly over the whole repository. +""" + +from collections.abc import Callable + +from typing_extensions import assert_type + +import alternative + + +@alternative.reference +def normal_function(count: int, label: str) -> str: + """Return a labelled value repeated a requested number of times.""" + return label * count + + +assert_type(normal_function(2, "a"), str) +normal_function_callable: Callable[[int, str], str] = normal_function + + +class NormalMethods: + """Container for instance method typing probes.""" + + @alternative.reference + def method(self, count: int, label: str) -> str: + """Return a labelled value repeated a requested number of times.""" + return label * count + + @method.add + def method_extra(self, count: int, label: str) -> str: + """Return an upper-case labelled value repeated a requested number of times.""" + return label.upper() * count + + +normal_methods = NormalMethods() +assert_type(normal_methods.method(2, "a"), str) +assert_type(normal_methods.method_extra(2, "a"), str) +bound_method_callable: Callable[[int, str], str] = normal_methods.method +unbound_method_callable: Callable[[NormalMethods, int, str], str] = NormalMethods.method + + +class DescriptorMethods: + """Container for classmethod and staticmethod typing probes.""" + + @alternative.reference + @classmethod + def build(cls, count: int, label: str) -> str: + """Return a labelled value repeated a requested number of times.""" + return label * count + + @build.add + @classmethod + def build_extra(cls, count: int, label: str) -> str: + """Return an upper-case labelled value repeated a requested number of times.""" + return label.upper() * count + + @alternative.reference + @staticmethod + def parse(count: int, label: str) -> str: + """Return a labelled value repeated a requested number of times.""" + return label * count + + @parse.add + @staticmethod + def parse_extra(count: int, label: str) -> str: + """Return an upper-case labelled value repeated a requested number of times.""" + return label.upper() * count + + +assert_type(DescriptorMethods.build(2, "a"), str) +assert_type(DescriptorMethods().build(2, "a"), str) +assert_type(DescriptorMethods.build_extra(2, "a"), str) +classmethod_callable: Callable[[int, str], str] = DescriptorMethods.build +bound_classmethod_callable: Callable[[int, str], str] = DescriptorMethods().build + +assert_type(DescriptorMethods.parse(2, "a"), str) +assert_type(DescriptorMethods().parse(2, "a"), str) +assert_type(DescriptorMethods.parse_extra(2, "a"), str) +staticmethod_callable: Callable[[int, str], str] = DescriptorMethods.parse +bound_staticmethod_callable: Callable[[int, str], str] = DescriptorMethods().parse diff --git a/uv.lock b/uv.lock index b7b14cd..1c78cc9 100644 --- a/uv.lock +++ b/uv.lock @@ -34,6 +34,7 @@ dev = [ { name = "pytest-benchmark" }, { name = "pytest-cov" }, { name = "ruff" }, + { name = "typing-extensions" }, ] docs = [ { name = "sphinx" }, @@ -53,6 +54,7 @@ dev = [ { name = "mypy", specifier = ">=2.0.0" }, { name = "pyrefly", specifier = ">=0.64.1" }, { name = "ruff", specifier = ">=0.15.12" }, + { name = "typing-extensions", specifier = ">=4.15.0" }, ] docs = [ { name = "sphinx", specifier = ">=8.1.3" }, From db54d03414ceef93f4ad65ab0762de053d22dc41 Mon Sep 17 00:00:00 2001 From: Oliver Bristow Date: Sat, 9 May 2026 23:11:58 +0100 Subject: [PATCH 4/8] Use less expressive typing to help IDEs (PyCharm) analyse types --- AGENTS.md | 1 + alternative.py | 399 ++++++++++++++++++++++++++------- scripts/pycharm-type-probes.sh | 123 ++++++++++ 3 files changed, 442 insertions(+), 81 deletions(-) create mode 100755 scripts/pycharm-type-probes.sh diff --git a/AGENTS.md b/AGENTS.md index 45b4bde..de110ea 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -11,5 +11,6 @@ The repository defines testing via GitHub actions. When contributing: * Format code with `uv run --dev ruff format .` before committing. * Keep the documentation in `docs/` up to date with user-facing behavior, API, and workflow changes. Documentation must compile without warnings. * Keep `alternative.py` strictly typed: do not use `typing.Any` or `Any`, and do not add mypy or pyrefly suppression comments. Fix the annotations so public decorators remain transparent to type checkers and IDEs. +* For PyCharm-specific typing regressions, verify `typing_tests/type_probes.py` with `scripts/pycharm-type-probes.sh`. The script must produce no output when the probe file is clean. * Any change to branching paths in `alternative.py` must be followed by a branch coverage run and review for material missing runtime coverage using `uv run --dev pytest --cov=alternative --cov-branch --cov-report=term-missing:skip-covered`. * Name tests and functions in `snake_case` and give them triple-quoted docstrings similar to the current codebase. diff --git a/alternative.py b/alternative.py index 459dd40..06eed91 100644 --- a/alternative.py +++ b/alternative.py @@ -4,6 +4,7 @@ import inspect import os from functools import wraps, lru_cache +from types import FrameType from typing import ( Callable, Concatenate, @@ -11,6 +12,7 @@ Generic, ParamSpec, Protocol, + Type, TypeVar, cast, overload, @@ -54,6 +56,13 @@ def __lt__(self, other: object, /) -> bool: ... S = TypeVar("S") Owner = TypeVar("Owner") F = TypeVar("F", bound=Callable[..., object]) +# These arity variables keep callable wrappers transparent in IDEs that do not +# fully resolve ParamSpec-based __call__ methods. +A1 = TypeVar("A1") +A2 = TypeVar("A2") +A3 = TypeVar("A3") +A4 = TypeVar("A4") +A5 = TypeVar("A5") class _NoReceiver: @@ -62,24 +71,14 @@ class _NoReceiver: class _TypedDescriptor(Protocol[P, R_co]): def __get__( - self, instance: object | None, owner: type[object] | None = None, / + self, instance: object | None, owner: Type[object] | None = None, / ) -> Callable[P, R_co]: ... class _ReferenceDecorator(Protocol): @overload def __call__( - self, implementation: classmethod[Owner, P, R], / - ) -> Alternatives[_NoReceiver, P, R]: ... - - @overload - def __call__( - self, implementation: staticmethod[P, R], / - ) -> Alternatives[_NoReceiver, P, R]: ... - - @overload - def __call__( - self, implementation: Callable[Concatenate[type[Owner], P], R], / + self, implementation: Callable[Concatenate[Type[Owner], P], R], / ) -> Alternatives[_NoReceiver, P, R]: ... @overload @@ -109,6 +108,13 @@ class CrossAlternativesImplementationError(AlternativeError): """Cannot add an Implementation object that belongs to a different Alternatives set.""" +def _frame_back(frame: FrameType | None) -> FrameType | None: + """Return the previous frame when frame inspection is available.""" + if frame is None: + return None + return frame.f_back + + def _get_caller_path() -> str | None: """ Return 'module.QualName (file.py:line)' pointing to the line @@ -118,18 +124,13 @@ def _get_caller_path() -> str | None: """ frame = inspect.currentframe() # Walk back two frames: 0=this, 1=caller, 2=caller of caller - if not frame or not frame.f_back: - caller = None # no two-up frame - else: - caller = frame.f_back.f_back - - # walk though any frames that are in the current file as they will not be helpful - while caller is None or caller.f_code.co_filename == __file__: - # a bit of a jiggly approach of handling caller being None to make type checking easier and help coverage - if caller: - caller = caller.f_back - if caller is None: - return ". ()" + caller = _frame_back(_frame_back(frame)) + + # Walk through frames in this file, since they are not useful call sites. + while caller is not None and caller.f_code.co_filename == __file__: + caller = _frame_back(caller) + if caller is None: + return ". ()" code = caller.f_code module = caller.f_globals.get("__name__", "") qualname = getattr(code, "co_qualname", code.co_name) @@ -151,7 +152,7 @@ def _maybe_get_caller_path() -> str | None: def _bind_implementation( implementation: Callable[..., R] | _TypedDescriptor[P, R], instance: object | None, - owner: type[object] | None, + owner: Type[object] | None, ) -> Callable[..., R]: """Bind an implementation using descriptor semantics when available.""" if owner is None and instance is not None: @@ -166,9 +167,53 @@ def _bind_implementation( class _BoundAlternatives(Generic[S, P, R]): alternatives: Alternatives[S, P, R] instance: object | None - owner: type[object] | None + owner: Type[object] | None + + @overload + def __call__(self: _BoundAlternatives[S, [], R]) -> R: ... - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: + @overload + def __call__(self: _BoundAlternatives[S, [A1], R], arg1: A1, /) -> R: ... + + @overload + def __call__( + self: _BoundAlternatives[S, [A1, A2], R], arg1: A1, arg2: A2, / + ) -> R: ... + + @overload + def __call__( + self: _BoundAlternatives[S, [A1, A2, A3], R], + arg1: A1, + arg2: A2, + arg3: A3, + /, + ) -> R: ... + + @overload + def __call__( + self: _BoundAlternatives[S, [A1, A2, A3, A4], R], + arg1: A1, + arg2: A2, + arg3: A3, + arg4: A4, + /, + ) -> R: ... + + @overload + def __call__( + self: _BoundAlternatives[S, [A1, A2, A3, A4, A5], R], + arg1: A1, + arg2: A2, + arg3: A3, + arg4: A4, + arg5: A5, + /, + ) -> R: ... + + @overload + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ... + + def __call__(self, *args: object, **kwargs: object) -> R: implementation: Callable[..., R] = _bind_implementation( self.alternatives.callable, self.instance, self.owner ) @@ -177,10 +222,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: @overload def add( self: _BoundAlternatives[_NoReceiver, P, R], - implementation: Callable[P, R] - | Callable[Concatenate[type[Owner], P], R] - | classmethod[Owner, P, R] - | staticmethod[P, R], + implementation: Callable[P, R] | Callable[Concatenate[Type[Owner], P], R], *, default: bool = False, ) -> Implementation[_NoReceiver, P, R]: ... @@ -197,7 +239,7 @@ def add( def add( self, *, default: bool = False ) -> Callable[ - [Callable[..., R] | classmethod[Owner, P, R] | staticmethod[P, R]], + [Callable[..., R]], Implementation[S, P, R], ]: ... @@ -227,9 +269,53 @@ def callable(self) -> Callable[..., R] | _TypedDescriptor[P, R]: class _BoundImplementation(Generic[S, P, R]): implementation: Implementation[S, P, R] instance: object | None - owner: type[object] | None + owner: Type[object] | None + + @overload + def __call__(self: _BoundImplementation[S, [], R]) -> R: ... + + @overload + def __call__(self: _BoundImplementation[S, [A1], R], arg1: A1, /) -> R: ... + + @overload + def __call__( + self: _BoundImplementation[S, [A1, A2], R], arg1: A1, arg2: A2, / + ) -> R: ... + + @overload + def __call__( + self: _BoundImplementation[S, [A1, A2, A3], R], + arg1: A1, + arg2: A2, + arg3: A3, + /, + ) -> R: ... + + @overload + def __call__( + self: _BoundImplementation[S, [A1, A2, A3, A4], R], + arg1: A1, + arg2: A2, + arg3: A3, + arg4: A4, + /, + ) -> R: ... + + @overload + def __call__( + self: _BoundImplementation[S, [A1, A2, A3, A4, A5], R], + arg1: A1, + arg2: A2, + arg3: A3, + arg4: A4, + arg5: A5, + /, + ) -> R: ... - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: + @overload + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ... + + def __call__(self, *args: object, **kwargs: object) -> R: implementation: Callable[..., R] = _bind_implementation( self.implementation.implementation, self.instance, self.owner ) @@ -242,10 +328,7 @@ def alternatives(self) -> Alternatives[S, P, R]: @overload def add( self: _BoundImplementation[_NoReceiver, P, R], - implementation: Callable[P, R] - | Callable[Concatenate[type[Owner], P], R] - | classmethod[Owner, P, R] - | staticmethod[P, R], + implementation: Callable[P, R] | Callable[Concatenate[Type[Owner], P], R], *, default: bool = False, ) -> Implementation[_NoReceiver, P, R]: ... @@ -262,7 +345,7 @@ def add( def add( self, *, default: bool = False ) -> Callable[ - [Callable[..., R] | classmethod[Owner, P, R] | staticmethod[P, R]], + [Callable[..., R]], Implementation[S, P, R], ]: ... @@ -306,10 +389,7 @@ def __init__( @overload def add( self: Alternatives[_NoReceiver, P, R], - implementation: Callable[P, R] - | Callable[Concatenate[type[Owner], P], R] - | classmethod[Owner, P, R] - | staticmethod[P, R], + implementation: Callable[P, R] | Callable[Concatenate[Type[Owner], P], R], *, default: bool = False, ) -> Implementation[_NoReceiver, P, R]: ... @@ -326,7 +406,7 @@ def add( def add( self, *, default: bool = False ) -> Callable[ - [Callable[..., R] | classmethod[Owner, P, R] | staticmethod[P, R]], + [Callable[..., R]], Implementation[S, P, R], ]: ... @@ -390,16 +470,18 @@ def callable(self) -> Callable[..., R] | _TypedDescriptor[P, R]: """Return the active implementation. Setting the default implementation is disabled after this is accessed.""" - if self._callable is None: + callable_ = self._callable + if callable_ is None: # finalise the callable if self._default: - self._callable = self._default.implementation + callable_ = self._default.implementation else: - self._callable = self.reference + callable_ = self.reference + self._callable = callable_ self._debug_callable_used = _maybe_get_caller_path() # access the list of implementations to freeze them assert self.implementations - return self._callable + return callable_ @property def implementations(self) -> list[Implementation[S, P, R]]: @@ -408,6 +490,91 @@ def implementations(self) -> list[Implementation[S, P, R]]: self._debug_implementations_used = _maybe_get_caller_path() return self._implementations + @overload + def __call__(self: Alternatives[_NoReceiver, [], R]) -> R: ... + + @overload + def __call__(self: Alternatives[_NoReceiver, [A1], R], arg1: A1, /) -> R: ... + + @overload + def __call__( + self: Alternatives[_NoReceiver, [A1, A2], R], arg1: A1, arg2: A2, / + ) -> R: ... + + @overload + def __call__( + self: Alternatives[_NoReceiver, [A1, A2, A3], R], + arg1: A1, + arg2: A2, + arg3: A3, + /, + ) -> R: ... + + @overload + def __call__( + self: Alternatives[_NoReceiver, [A1, A2, A3, A4], R], + arg1: A1, + arg2: A2, + arg3: A3, + arg4: A4, + /, + ) -> R: ... + + @overload + def __call__( + self: Alternatives[_NoReceiver, [A1, A2, A3, A4, A5], R], + arg1: A1, + arg2: A2, + arg3: A3, + arg4: A4, + arg5: A5, + /, + ) -> R: ... + + @overload + def __call__(self: Alternatives[S, [], R], receiver: S, /) -> R: ... + + @overload + def __call__(self: Alternatives[S, [A1], R], receiver: S, arg1: A1, /) -> R: ... + + @overload + def __call__( + self: Alternatives[S, [A1, A2], R], receiver: S, arg1: A1, arg2: A2, / + ) -> R: ... + + @overload + def __call__( + self: Alternatives[S, [A1, A2, A3], R], + receiver: S, + arg1: A1, + arg2: A2, + arg3: A3, + /, + ) -> R: ... + + @overload + def __call__( + self: Alternatives[S, [A1, A2, A3, A4], R], + receiver: S, + arg1: A1, + arg2: A2, + arg3: A3, + arg4: A4, + /, + ) -> R: ... + + @overload + def __call__( + self: Alternatives[S, [A1, A2, A3, A4, A5], R], + receiver: S, + arg1: A1, + arg2: A2, + arg3: A3, + arg4: A4, + arg5: A5, + /, + ) -> R: ... + @overload def __call__( self: Alternatives[_NoReceiver, P, R], @@ -433,28 +600,28 @@ def __call__( @overload def __get__( - self, instance: None, owner: type[object] | None = None + self, instance: None, owner: Type[object] | None = None ) -> Alternatives[S, P, R]: ... @overload def __get__( self: Alternatives[_NoReceiver, P, R], instance: object, - owner: type[object] | None = None, + owner: Type[object] | None = None, ) -> _BoundAlternatives[_NoReceiver, P, R]: ... @overload def __get__( - self, instance: S, owner: type[object] | None = None + self, instance: S, owner: Type[object] | None = None ) -> _BoundAlternatives[S, P, R]: ... @overload def __get__( - self, instance: object, owner: type[object] | None = None + self, instance: object, owner: Type[object] | None = None ) -> Callable[Concatenate[S, P], R]: ... def __get__( - self, instance: object | None, owner: type[object] | None = None + self, instance: object | None, owner: Type[object] | None = None ) -> object: if instance is None and not self._binds_on_class_access(): return self @@ -674,6 +841,91 @@ def __repr__(self) -> str: return f"Implementation({implementation_name}, label={self.label!r})" return f"Implementation({implementation_name})" + @overload + def __call__(self: Implementation[_NoReceiver, [], R]) -> R: ... + + @overload + def __call__(self: Implementation[_NoReceiver, [A1], R], arg1: A1, /) -> R: ... + + @overload + def __call__( + self: Implementation[_NoReceiver, [A1, A2], R], arg1: A1, arg2: A2, / + ) -> R: ... + + @overload + def __call__( + self: Implementation[_NoReceiver, [A1, A2, A3], R], + arg1: A1, + arg2: A2, + arg3: A3, + /, + ) -> R: ... + + @overload + def __call__( + self: Implementation[_NoReceiver, [A1, A2, A3, A4], R], + arg1: A1, + arg2: A2, + arg3: A3, + arg4: A4, + /, + ) -> R: ... + + @overload + def __call__( + self: Implementation[_NoReceiver, [A1, A2, A3, A4, A5], R], + arg1: A1, + arg2: A2, + arg3: A3, + arg4: A4, + arg5: A5, + /, + ) -> R: ... + + @overload + def __call__(self: Implementation[S, [], R], receiver: S, /) -> R: ... + + @overload + def __call__(self: Implementation[S, [A1], R], receiver: S, arg1: A1, /) -> R: ... + + @overload + def __call__( + self: Implementation[S, [A1, A2], R], receiver: S, arg1: A1, arg2: A2, / + ) -> R: ... + + @overload + def __call__( + self: Implementation[S, [A1, A2, A3], R], + receiver: S, + arg1: A1, + arg2: A2, + arg3: A3, + /, + ) -> R: ... + + @overload + def __call__( + self: Implementation[S, [A1, A2, A3, A4], R], + receiver: S, + arg1: A1, + arg2: A2, + arg3: A3, + arg4: A4, + /, + ) -> R: ... + + @overload + def __call__( + self: Implementation[S, [A1, A2, A3, A4, A5], R], + receiver: S, + arg1: A1, + arg2: A2, + arg3: A3, + arg4: A4, + arg5: A5, + /, + ) -> R: ... + @overload def __call__( self: Implementation[_NoReceiver, P, R], @@ -699,28 +951,28 @@ def __call__( @overload def __get__( - self, instance: None, owner: type[object] | None = None + self, instance: None, owner: Type[object] | None = None ) -> Implementation[S, P, R]: ... @overload def __get__( self: Implementation[_NoReceiver, P, R], instance: object, - owner: type[object] | None = None, + owner: Type[object] | None = None, ) -> _BoundImplementation[_NoReceiver, P, R]: ... @overload def __get__( - self, instance: S, owner: type[object] | None = None + self, instance: S, owner: Type[object] | None = None ) -> _BoundImplementation[S, P, R]: ... @overload def __get__( - self, instance: object, owner: type[object] | None = None + self, instance: object, owner: Type[object] | None = None ) -> Callable[Concatenate[S, P], R]: ... def __get__( - self, instance: object | None, owner: type[object] | None = None + self, instance: object | None, owner: Type[object] | None = None ) -> object: if instance is None and not isinstance(self.implementation, classmethod): return self @@ -729,10 +981,7 @@ def __get__( @overload def add( self: Implementation[_NoReceiver, P, R], - implementation: Callable[P, R] - | Callable[Concatenate[type[Owner], P], R] - | classmethod[Owner, P, R] - | staticmethod[P, R], + implementation: Callable[P, R] | Callable[Concatenate[Type[Owner], P], R], *, default: bool = False, ) -> Implementation[_NoReceiver, P, R]: ... @@ -749,7 +998,7 @@ def add( def add( self, *, default: bool = False ) -> Callable[ - [Callable[..., R] | classmethod[Owner, P, R] | staticmethod[P, R]], + [Callable[..., R]], Implementation[S, P, R], ]: ... @@ -766,25 +1015,9 @@ def add( return add(implementation, default=default) -@overload -def reference(*, default: bool = False) -> _ReferenceDecorator: ... - - @overload def reference( - implementation: classmethod[Owner, P, R], *, default: bool = False -) -> Alternatives[_NoReceiver, P, R]: ... - - -@overload -def reference( - implementation: staticmethod[P, R], *, default: bool = False -) -> Alternatives[_NoReceiver, P, R]: ... - - -@overload -def reference( - implementation: Callable[Concatenate[type[Owner], P], R], + implementation: Callable[Concatenate[Type[Owner], P], R], *, default: bool = False, ) -> Alternatives[_NoReceiver, P, R]: ... @@ -802,6 +1035,10 @@ def reference( ) -> Alternatives[_NoReceiver, P, R]: ... +@overload +def reference(*, default: bool = False) -> _ReferenceDecorator: ... + + def reference( implementation: object = _UNDEFINED_VALUE, *, @@ -816,7 +1053,7 @@ def inner(f: object) -> object: default=default, ) - return cast(_ReferenceDecorator, inner) + return cast(_ReferenceDecorator, cast(object, inner)) else: return Alternatives( cast(Callable[..., object] | _TypedDescriptor[..., object], implementation), diff --git a/scripts/pycharm-type-probes.sh b/scripts/pycharm-type-probes.sh new file mode 100755 index 0000000..f6d0b30 --- /dev/null +++ b/scripts/pycharm-type-probes.sh @@ -0,0 +1,123 @@ +#!/usr/bin/env bash +set -Eeuo pipefail + +script_dir="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd -P)" +repo_root="$(cd -- "${script_dir}/.." && pwd -P)" + +inspect_sh="${PYCHARM_INSPECT_SH:-}" +if [[ -z "${inspect_sh}" ]]; then + toolbox_inspect="${HOME}/.local/share/JetBrains/Toolbox/apps/pycharm-professional/bin/inspect.sh" + if [[ -x "${toolbox_inspect}" ]]; then + inspect_sh="${toolbox_inspect}" + elif command -v inspect.sh >/dev/null 2>&1; then + inspect_sh="$(command -v inspect.sh)" + fi +fi + +if [[ -z "${inspect_sh}" || ! -x "${inspect_sh}" ]]; then + printf 'PyCharm inspect.sh was not found. Set PYCHARM_INSPECT_SH to its absolute path.\n' >&2 + exit 2 +fi + +tmp_root="$(mktemp -d "${TMPDIR:-/tmp}/alternative-pycharm-type-probes.XXXXXXXX")" +cleanup() { + case "${tmp_root}" in + "${TMPDIR:-/tmp}"/alternative-pycharm-type-probes.*) + rm -rf -- "${tmp_root}" + ;; + esac +} +trap cleanup EXIT + +profile="${tmp_root}/profile.xml" +results="${tmp_root}/results" +vmoptions="${tmp_root}/pycharm.vmoptions" +stdout_log="${tmp_root}/inspect.stdout" +stderr_log="${tmp_root}/inspect.stderr" +problems="${tmp_root}/problems.xml" + +cat >"${profile}" <<'XML' + + + + +XML + +cat >"${vmoptions}" <"${stdout_log}" 2>"${stderr_log}"; then + printf 'PyCharm inspect.sh failed.\n' >&2 + if [[ -s "${stderr_log}" ]]; then + printf '\n[stderr]\n' >&2 + tail --lines=80 "${stderr_log}" >&2 + fi + if [[ -s "${stdout_log}" ]]; then + printf '\n[stdout]\n' >&2 + tail --lines=80 "${stdout_log}" >&2 + fi + exit 2 +fi + +if [[ ! -d "${results}" ]]; then + printf 'PyCharm inspect.sh did not create an inspection results directory.\n' >&2 + exit 2 +fi + +: >"${problems}" +# PyNestedDecoratorsInspection is intentionally excluded here. PyCharm reports +# a false positive for correctly typed decorators stacked over classmethod or +# staticmethod; see PyNestedDecoratorsInspection-issue.md and related YouTrack +# issue PY-34368. +for inspection in \ + PyAssertTypeInspection \ + PyTypeCheckerInspection \ + PyUnresolvedReferencesInspection +do + report="${results}/${inspection}.xml" + if [[ ! -f "${report}" ]]; then + continue + fi + + normalized_report="${tmp_root}/${inspection}.normalized.xml" + sed -e 's##\ +#g' -e 's##\ +#g' "${report}" >"${normalized_report}" + + awk -v inspection="${inspection}" ' + // { + in_problem = 1 + relevant = 0 + block = $0 ORS + next + } + in_problem { + block = block $0 ORS + if ($0 ~ /typing_tests\/type_probes\.py/) { + relevant = 1 + } + if ($0 ~ /<\/problem>/) { + if (relevant) { + print "" + printf "%s", block + } + in_problem = 0 + relevant = 0 + block = "" + } + } + ' "${normalized_report}" >>"${problems}" +done + +if [[ -s "${problems}" ]]; then + cat "${problems}" + exit 1 +fi From f34a968f692d966a17c9b4be124f3466e057e7ed Mon Sep 17 00:00:00 2001 From: Oliver Bristow Date: Sun, 10 May 2026 12:35:00 +0100 Subject: [PATCH 5/8] pull oyt .pyi + add pyright to typing_tests A full pyright run complains about things that neither pyrefly or mypy do so it just seems a bit weak --- .github/workflows/ci.yml | 3 + AGENTS.md | 1 + alternative.py | 539 +---------------------------- alternative.pyi | 580 ++++++++++++++++++++++++++++++++ pyproject.toml | 24 ++ test_packaging.py | 29 ++ typing_tests/callable_probes.py | 51 +++ typing_tests/type_probes.py | 9 - uv.lock | 99 ++++++ 9 files changed, 803 insertions(+), 532 deletions(-) create mode 100644 alternative.pyi create mode 100644 test_packaging.py create mode 100644 typing_tests/callable_probes.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a83990c..f9aad94 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,6 +47,9 @@ jobs: - name: Pyrefly type check run: uv run --dev pyrefly check . + - name: Pyright type probe check + run: uv run --dev pyright typing_tests + - name: Mypy type check run: uv run --dev mypy . diff --git a/AGENTS.md b/AGENTS.md index de110ea..36d49d3 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -5,6 +5,7 @@ The repository defines testing via GitHub actions. When contributing: * `uv run --dev ruff format --check --diff .` * `uv run --dev ruff check .` * `uv run --dev pyrefly check .` + * `uv run --dev pyright typing_tests` * `uv run --dev mypy .` * `uv run --dev pytest --verbosity=2 --cov=alternative --cov-report=xml --cov-fail-under=100 --junit-xml=test-results.xml` * `uv run --group=docs sphinx-build --fail-on-warning --keep-going --builder=html docs /tmp/alternative-docs-html` diff --git a/alternative.py b/alternative.py index 06eed91..825ded6 100644 --- a/alternative.py +++ b/alternative.py @@ -7,7 +7,6 @@ from types import FrameType from typing import ( Callable, - Concatenate, Final, Generic, ParamSpec, @@ -15,7 +14,6 @@ Type, TypeVar, cast, - overload, ) @@ -54,15 +52,7 @@ def __lt__(self, other: object, /) -> bool: ... R_co = TypeVar("R_co", covariant=True) M = TypeVar("M") S = TypeVar("S") -Owner = TypeVar("Owner") F = TypeVar("F", bound=Callable[..., object]) -# These arity variables keep callable wrappers transparent in IDEs that do not -# fully resolve ParamSpec-based __call__ methods. -A1 = TypeVar("A1") -A2 = TypeVar("A2") -A3 = TypeVar("A3") -A4 = TypeVar("A4") -A5 = TypeVar("A5") class _NoReceiver: @@ -75,23 +65,6 @@ def __get__( ) -> Callable[P, R_co]: ... -class _ReferenceDecorator(Protocol): - @overload - def __call__( - self, implementation: Callable[Concatenate[Type[Owner], P], R], / - ) -> Alternatives[_NoReceiver, P, R]: ... - - @overload - def __call__( - self, implementation: Callable[Concatenate[S, P], R], / - ) -> Alternatives[S, P, R]: ... - - @overload - def __call__( - self, implementation: Callable[P, R], / - ) -> Alternatives[_NoReceiver, P, R]: ... - - class AlternativeError(Exception): """Base class for all alternative errors.""" @@ -169,80 +142,12 @@ class _BoundAlternatives(Generic[S, P, R]): instance: object | None owner: Type[object] | None - @overload - def __call__(self: _BoundAlternatives[S, [], R]) -> R: ... - - @overload - def __call__(self: _BoundAlternatives[S, [A1], R], arg1: A1, /) -> R: ... - - @overload - def __call__( - self: _BoundAlternatives[S, [A1, A2], R], arg1: A1, arg2: A2, / - ) -> R: ... - - @overload - def __call__( - self: _BoundAlternatives[S, [A1, A2, A3], R], - arg1: A1, - arg2: A2, - arg3: A3, - /, - ) -> R: ... - - @overload - def __call__( - self: _BoundAlternatives[S, [A1, A2, A3, A4], R], - arg1: A1, - arg2: A2, - arg3: A3, - arg4: A4, - /, - ) -> R: ... - - @overload - def __call__( - self: _BoundAlternatives[S, [A1, A2, A3, A4, A5], R], - arg1: A1, - arg2: A2, - arg3: A3, - arg4: A4, - arg5: A5, - /, - ) -> R: ... - - @overload - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ... - def __call__(self, *args: object, **kwargs: object) -> R: implementation: Callable[..., R] = _bind_implementation( self.alternatives.callable, self.instance, self.owner ) return implementation(*args, **kwargs) - @overload - def add( - self: _BoundAlternatives[_NoReceiver, P, R], - implementation: Callable[P, R] | Callable[Concatenate[Type[Owner], P], R], - *, - default: bool = False, - ) -> Implementation[_NoReceiver, P, R]: ... - - @overload - def add( - self, - implementation: Callable[Concatenate[S, P], R] | Implementation[S, P, R], - *, - default: bool = False, - ) -> Implementation[S, P, R]: ... - - @overload - def add( - self, *, default: bool = False - ) -> Callable[ - [Callable[..., R]], - Implementation[S, P, R], - ]: ... - def add( self, implementation: object = _UNDEFINED_VALUE, @@ -271,50 +176,6 @@ class _BoundImplementation(Generic[S, P, R]): instance: object | None owner: Type[object] | None - @overload - def __call__(self: _BoundImplementation[S, [], R]) -> R: ... - - @overload - def __call__(self: _BoundImplementation[S, [A1], R], arg1: A1, /) -> R: ... - - @overload - def __call__( - self: _BoundImplementation[S, [A1, A2], R], arg1: A1, arg2: A2, / - ) -> R: ... - - @overload - def __call__( - self: _BoundImplementation[S, [A1, A2, A3], R], - arg1: A1, - arg2: A2, - arg3: A3, - /, - ) -> R: ... - - @overload - def __call__( - self: _BoundImplementation[S, [A1, A2, A3, A4], R], - arg1: A1, - arg2: A2, - arg3: A3, - arg4: A4, - /, - ) -> R: ... - - @overload - def __call__( - self: _BoundImplementation[S, [A1, A2, A3, A4, A5], R], - arg1: A1, - arg2: A2, - arg3: A3, - arg4: A4, - arg5: A5, - /, - ) -> R: ... - - @overload - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ... - def __call__(self, *args: object, **kwargs: object) -> R: implementation: Callable[..., R] = _bind_implementation( self.implementation.implementation, self.instance, self.owner @@ -325,30 +186,6 @@ def __call__(self, *args: object, **kwargs: object) -> R: def alternatives(self) -> Alternatives[S, P, R]: return self.implementation.alternatives - @overload - def add( - self: _BoundImplementation[_NoReceiver, P, R], - implementation: Callable[P, R] | Callable[Concatenate[Type[Owner], P], R], - *, - default: bool = False, - ) -> Implementation[_NoReceiver, P, R]: ... - - @overload - def add( - self, - implementation: Callable[Concatenate[S, P], R] | Implementation[S, P, R], - *, - default: bool = False, - ) -> Implementation[S, P, R]: ... - - @overload - def add( - self, *, default: bool = False - ) -> Callable[ - [Callable[..., R]], - Implementation[S, P, R], - ]: ... - def add( self, implementation: object = _UNDEFINED_VALUE, @@ -386,30 +223,6 @@ def __init__( self._debug_implementations_used: str | None = None self.add(imp, default=default) - @overload - def add( - self: Alternatives[_NoReceiver, P, R], - implementation: Callable[P, R] | Callable[Concatenate[Type[Owner], P], R], - *, - default: bool = False, - ) -> Implementation[_NoReceiver, P, R]: ... - - @overload - def add( - self, - implementation: Callable[Concatenate[S, P], R] | Implementation[S, P, R], - *, - default: bool = False, - ) -> Implementation[S, P, R]: ... - - @overload - def add( - self, *, default: bool = False - ) -> Callable[ - [Callable[..., R]], - Implementation[S, P, R], - ]: ... - def add( self, implementation: object = _UNDEFINED_VALUE, @@ -490,135 +303,13 @@ def implementations(self) -> list[Implementation[S, P, R]]: self._debug_implementations_used = _maybe_get_caller_path() return self._implementations - @overload - def __call__(self: Alternatives[_NoReceiver, [], R]) -> R: ... - - @overload - def __call__(self: Alternatives[_NoReceiver, [A1], R], arg1: A1, /) -> R: ... - - @overload - def __call__( - self: Alternatives[_NoReceiver, [A1, A2], R], arg1: A1, arg2: A2, / - ) -> R: ... - - @overload - def __call__( - self: Alternatives[_NoReceiver, [A1, A2, A3], R], - arg1: A1, - arg2: A2, - arg3: A3, - /, - ) -> R: ... - - @overload - def __call__( - self: Alternatives[_NoReceiver, [A1, A2, A3, A4], R], - arg1: A1, - arg2: A2, - arg3: A3, - arg4: A4, - /, - ) -> R: ... - - @overload - def __call__( - self: Alternatives[_NoReceiver, [A1, A2, A3, A4, A5], R], - arg1: A1, - arg2: A2, - arg3: A3, - arg4: A4, - arg5: A5, - /, - ) -> R: ... - - @overload - def __call__(self: Alternatives[S, [], R], receiver: S, /) -> R: ... - - @overload - def __call__(self: Alternatives[S, [A1], R], receiver: S, arg1: A1, /) -> R: ... - - @overload - def __call__( - self: Alternatives[S, [A1, A2], R], receiver: S, arg1: A1, arg2: A2, / - ) -> R: ... - - @overload - def __call__( - self: Alternatives[S, [A1, A2, A3], R], - receiver: S, - arg1: A1, - arg2: A2, - arg3: A3, - /, - ) -> R: ... - - @overload - def __call__( - self: Alternatives[S, [A1, A2, A3, A4], R], - receiver: S, - arg1: A1, - arg2: A2, - arg3: A3, - arg4: A4, - /, - ) -> R: ... - - @overload - def __call__( - self: Alternatives[S, [A1, A2, A3, A4, A5], R], - receiver: S, - arg1: A1, - arg2: A2, - arg3: A3, - arg4: A4, - arg5: A5, - /, - ) -> R: ... - - @overload - def __call__( - self: Alternatives[_NoReceiver, P, R], - *args: P.args, - **kwargs: P.kwargs, - ) -> R: ... - - @overload - def __call__(self, receiver: S, *args: P.args, **kwargs: P.kwargs) -> R: ... - - def __call__( - self, - receiver: object = _UNDEFINED_VALUE, - *args: object, - **kwargs: object, - ) -> R: + def __call__(self, *args: object, **kwargs: object) -> R: implementation: Callable[..., R] = _bind_implementation( self.callable, None, None ) - if isinstance(receiver, _Undefined): + if not args: return implementation(**kwargs) - return implementation(receiver, *args, **kwargs) - - @overload - def __get__( - self, instance: None, owner: Type[object] | None = None - ) -> Alternatives[S, P, R]: ... - - @overload - def __get__( - self: Alternatives[_NoReceiver, P, R], - instance: object, - owner: Type[object] | None = None, - ) -> _BoundAlternatives[_NoReceiver, P, R]: ... - - @overload - def __get__( - self, instance: S, owner: Type[object] | None = None - ) -> _BoundAlternatives[S, P, R]: ... - - @overload - def __get__( - self, instance: object, owner: Type[object] | None = None - ) -> Callable[Concatenate[S, P], R]: ... + return implementation(*args, **kwargs) def __get__( self, instance: object | None, owner: Type[object] | None = None @@ -671,19 +362,6 @@ def measure( except TypeError: return result - @overload - def pytest_parametrize( - self, - *, - only_default: bool = False, - ) -> Callable[[F], F]: ... - @overload - def pytest_parametrize( - self, - test: F, - *, - only_default: bool = False, - ) -> F: ... def pytest_parametrize( self, test: F | _Undefined = _UNDEFINED_VALUE, @@ -700,7 +378,7 @@ def pytest_parametrize( if isinstance(test, _Undefined): def decorator(f: F) -> F: - return self.pytest_parametrize(f, only_default=only_default) + return cast(F, self.pytest_parametrize(f, only_default=only_default)) return decorator @@ -715,24 +393,6 @@ def inner(*args: object, **kwargs: object) -> object: return cast(F, inner) - @overload - def pytest_parametrize_pairs( - self, - *, - n_cache: int | None = 0, - double_reference: bool = False, - only_default: bool = False, - ) -> Callable[[F], F]: ... - @overload - def pytest_parametrize_pairs( - self, - test: F, - *, - n_cache: int | None = 0, - double_reference: bool = False, - only_default: bool = False, - ) -> F: ... - def pytest_parametrize_pairs( self, test: F | _Undefined = _UNDEFINED_VALUE, @@ -756,11 +416,14 @@ def pytest_parametrize_pairs( if isinstance(test, _Undefined): def decorator(f: F) -> F: - return self.pytest_parametrize_pairs( - f, - n_cache=n_cache, - double_reference=double_reference, - only_default=only_default, + return cast( + F, + self.pytest_parametrize_pairs( + f, + n_cache=n_cache, + double_reference=double_reference, + only_default=only_default, + ), ) return decorator @@ -841,135 +504,13 @@ def __repr__(self) -> str: return f"Implementation({implementation_name}, label={self.label!r})" return f"Implementation({implementation_name})" - @overload - def __call__(self: Implementation[_NoReceiver, [], R]) -> R: ... - - @overload - def __call__(self: Implementation[_NoReceiver, [A1], R], arg1: A1, /) -> R: ... - - @overload - def __call__( - self: Implementation[_NoReceiver, [A1, A2], R], arg1: A1, arg2: A2, / - ) -> R: ... - - @overload - def __call__( - self: Implementation[_NoReceiver, [A1, A2, A3], R], - arg1: A1, - arg2: A2, - arg3: A3, - /, - ) -> R: ... - - @overload - def __call__( - self: Implementation[_NoReceiver, [A1, A2, A3, A4], R], - arg1: A1, - arg2: A2, - arg3: A3, - arg4: A4, - /, - ) -> R: ... - - @overload - def __call__( - self: Implementation[_NoReceiver, [A1, A2, A3, A4, A5], R], - arg1: A1, - arg2: A2, - arg3: A3, - arg4: A4, - arg5: A5, - /, - ) -> R: ... - - @overload - def __call__(self: Implementation[S, [], R], receiver: S, /) -> R: ... - - @overload - def __call__(self: Implementation[S, [A1], R], receiver: S, arg1: A1, /) -> R: ... - - @overload - def __call__( - self: Implementation[S, [A1, A2], R], receiver: S, arg1: A1, arg2: A2, / - ) -> R: ... - - @overload - def __call__( - self: Implementation[S, [A1, A2, A3], R], - receiver: S, - arg1: A1, - arg2: A2, - arg3: A3, - /, - ) -> R: ... - - @overload - def __call__( - self: Implementation[S, [A1, A2, A3, A4], R], - receiver: S, - arg1: A1, - arg2: A2, - arg3: A3, - arg4: A4, - /, - ) -> R: ... - - @overload - def __call__( - self: Implementation[S, [A1, A2, A3, A4, A5], R], - receiver: S, - arg1: A1, - arg2: A2, - arg3: A3, - arg4: A4, - arg5: A5, - /, - ) -> R: ... - - @overload - def __call__( - self: Implementation[_NoReceiver, P, R], - *args: P.args, - **kwargs: P.kwargs, - ) -> R: ... - - @overload - def __call__(self, receiver: S, *args: P.args, **kwargs: P.kwargs) -> R: ... - - def __call__( - self, - receiver: object = _UNDEFINED_VALUE, - *args: object, - **kwargs: object, - ) -> R: + def __call__(self, *args: object, **kwargs: object) -> R: implementation: Callable[..., R] = _bind_implementation( self.implementation, None, None ) - if isinstance(receiver, _Undefined): + if not args: return implementation(**kwargs) - return implementation(receiver, *args, **kwargs) - - @overload - def __get__( - self, instance: None, owner: Type[object] | None = None - ) -> Implementation[S, P, R]: ... - - @overload - def __get__( - self: Implementation[_NoReceiver, P, R], - instance: object, - owner: Type[object] | None = None, - ) -> _BoundImplementation[_NoReceiver, P, R]: ... - - @overload - def __get__( - self, instance: S, owner: Type[object] | None = None - ) -> _BoundImplementation[S, P, R]: ... - - @overload - def __get__( - self, instance: object, owner: Type[object] | None = None - ) -> Callable[Concatenate[S, P], R]: ... + return implementation(*args, **kwargs) def __get__( self, instance: object | None, owner: Type[object] | None = None @@ -978,30 +519,6 @@ def __get__( return self return _BoundImplementation(self, instance, owner) - @overload - def add( - self: Implementation[_NoReceiver, P, R], - implementation: Callable[P, R] | Callable[Concatenate[Type[Owner], P], R], - *, - default: bool = False, - ) -> Implementation[_NoReceiver, P, R]: ... - - @overload - def add( - self, - implementation: Callable[Concatenate[S, P], R] | Implementation[S, P, R], - *, - default: bool = False, - ) -> Implementation[S, P, R]: ... - - @overload - def add( - self, *, default: bool = False - ) -> Callable[ - [Callable[..., R]], - Implementation[S, P, R], - ]: ... - def add( self, implementation: object = _UNDEFINED_VALUE, @@ -1015,30 +532,6 @@ def add( return add(implementation, default=default) -@overload -def reference( - implementation: Callable[Concatenate[Type[Owner], P], R], - *, - default: bool = False, -) -> Alternatives[_NoReceiver, P, R]: ... - - -@overload -def reference( - implementation: Callable[Concatenate[S, P], R], *, default: bool = False -) -> Alternatives[S, P, R]: ... - - -@overload -def reference( - implementation: Callable[P, R], *, default: bool = False -) -> Alternatives[_NoReceiver, P, R]: ... - - -@overload -def reference(*, default: bool = False) -> _ReferenceDecorator: ... - - def reference( implementation: object = _UNDEFINED_VALUE, *, @@ -1053,7 +546,7 @@ def inner(f: object) -> object: default=default, ) - return cast(_ReferenceDecorator, cast(object, inner)) + return inner else: return Alternatives( cast(Callable[..., object] | _TypedDescriptor[..., object], implementation), diff --git a/alternative.pyi b/alternative.pyi new file mode 100644 index 0000000..52d3457 --- /dev/null +++ b/alternative.pyi @@ -0,0 +1,580 @@ +from __future__ import annotations + +from typing import ( + Callable, + Concatenate, + Generic, + ParamSpec, + Protocol, + TypeVar, + overload, +) + +DEBUG: bool + +__all__ = [ + "reference", + "Alternatives", + "Implementation", + "AlternativeError", + "AddTooLateError", + "MultipleDefaultsError", + "CrossAlternativesImplementationError", +] + +P = ParamSpec("P") +R = TypeVar("R") +R_co = TypeVar("R_co", covariant=True) +M = TypeVar("M") +S = TypeVar("S") +_Owner = TypeVar("_Owner") +F = TypeVar("F", bound=Callable[..., object]) +_A1 = TypeVar("_A1") +_A2 = TypeVar("_A2") +_A3 = TypeVar("_A3") +_A4 = TypeVar("_A4") +_A5 = TypeVar("_A5") + +class _NoReceiver: + """Marker type for descriptors that bind themselves before user calls.""" + +class _TypedDescriptor(Protocol[P, R_co]): + def __get__( + self, instance: object | None, owner: type[object] | None = None, / + ) -> Callable[P, R_co]: ... + +class _ReferenceDecorator(Protocol): + @overload + def __call__( + self, + implementation: classmethod[_Owner, P, R], + /, + ) -> Alternatives[_NoReceiver, P, R]: ... + @overload + def __call__( + self, + implementation: staticmethod[P, R], + /, + ) -> Alternatives[_NoReceiver, P, R]: ... + @overload + def __call__( + self, implementation: Callable[Concatenate[type[_Owner], P], R], / + ) -> Alternatives[_NoReceiver, P, R]: ... + @overload + def __call__( + self, implementation: Callable[Concatenate[S, P], R], / + ) -> Alternatives[S, P, R]: ... + @overload + def __call__( + self, implementation: Callable[P, R], / + ) -> Alternatives[_NoReceiver, P, R]: ... + +class AlternativeError(Exception): + """Base class for all alternative errors.""" + +class AddTooLateError(AlternativeError): + """Cannot add implementations after the alternatives have been invoked.""" + +class MultipleDefaultsError(AlternativeError): + """Cannot set the default implementation more than once.""" + +class CrossAlternativesImplementationError(AlternativeError): + """Cannot add an Implementation object that belongs to a different Alternatives set.""" + +class _BoundAlternatives(Generic[S, P, R]): + __match_args__: tuple[str, str, str] + alternatives: Alternatives[S, P, R] + instance: object | None + owner: type[object] | None + + def __init__( + self, + alternatives: Alternatives[S, P, R], + instance: object | None, + owner: type[object] | None, + ) -> None: ... + @overload + def __call__(self: _BoundAlternatives[S, [], R]) -> R: ... + @overload + def __call__(self: _BoundAlternatives[S, [_A1], R], arg1: _A1, /) -> R: ... + @overload + def __call__( + self: _BoundAlternatives[S, [_A1, _A2], R], arg1: _A1, arg2: _A2, / + ) -> R: ... + @overload + def __call__( + self: _BoundAlternatives[S, [_A1, _A2, _A3], R], + arg1: _A1, + arg2: _A2, + arg3: _A3, + /, + ) -> R: ... + @overload + def __call__( + self: _BoundAlternatives[S, [_A1, _A2, _A3, _A4], R], + arg1: _A1, + arg2: _A2, + arg3: _A3, + arg4: _A4, + /, + ) -> R: ... + @overload + def __call__( + self: _BoundAlternatives[S, [_A1, _A2, _A3, _A4, _A5], R], + arg1: _A1, + arg2: _A2, + arg3: _A3, + arg4: _A4, + arg5: _A5, + /, + ) -> R: ... + @overload + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ... + @overload + def add( + self: _BoundAlternatives[_NoReceiver, P, R], + implementation: Callable[P, R] + | Callable[Concatenate[type[_Owner], P], R] + | classmethod[_Owner, P, R] + | staticmethod[P, R], + *, + default: bool = False, + ) -> Implementation[_NoReceiver, P, R]: ... + @overload + def add( + self, + implementation: Callable[Concatenate[S, P], R] | Implementation[S, P, R], + *, + default: bool = False, + ) -> Implementation[S, P, R]: ... + @overload + def add( + self, *, default: bool = False + ) -> Callable[[Callable[..., R]], Implementation[S, P, R]]: ... + @property + def implementations(self) -> list[Implementation[S, P, R]]: ... + @property + def reference(self) -> Implementation[S, P, R]: ... + @property + def callable(self) -> Callable[..., R] | _TypedDescriptor[P, R]: ... + +class _BoundImplementation(Generic[S, P, R]): + __match_args__: tuple[str, str, str] + implementation: Implementation[S, P, R] + instance: object | None + owner: type[object] | None + + def __init__( + self, + implementation: Implementation[S, P, R], + instance: object | None, + owner: type[object] | None, + ) -> None: ... + @overload + def __call__(self: _BoundImplementation[S, [], R]) -> R: ... + @overload + def __call__(self: _BoundImplementation[S, [_A1], R], arg1: _A1, /) -> R: ... + @overload + def __call__( + self: _BoundImplementation[S, [_A1, _A2], R], arg1: _A1, arg2: _A2, / + ) -> R: ... + @overload + def __call__( + self: _BoundImplementation[S, [_A1, _A2, _A3], R], + arg1: _A1, + arg2: _A2, + arg3: _A3, + /, + ) -> R: ... + @overload + def __call__( + self: _BoundImplementation[S, [_A1, _A2, _A3, _A4], R], + arg1: _A1, + arg2: _A2, + arg3: _A3, + arg4: _A4, + /, + ) -> R: ... + @overload + def __call__( + self: _BoundImplementation[S, [_A1, _A2, _A3, _A4, _A5], R], + arg1: _A1, + arg2: _A2, + arg3: _A3, + arg4: _A4, + arg5: _A5, + /, + ) -> R: ... + @overload + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ... + @property + def alternatives(self) -> Alternatives[S, P, R]: ... + @overload + def add( + self: _BoundImplementation[_NoReceiver, P, R], + implementation: Callable[P, R] + | Callable[Concatenate[type[_Owner], P], R] + | classmethod[_Owner, P, R] + | staticmethod[P, R], + *, + default: bool = False, + ) -> Implementation[_NoReceiver, P, R]: ... + @overload + def add( + self, + implementation: Callable[Concatenate[S, P], R] | Implementation[S, P, R], + *, + default: bool = False, + ) -> Implementation[S, P, R]: ... + @overload + def add( + self, *, default: bool = False + ) -> Callable[[Callable[..., R]], Implementation[S, P, R]]: ... + +class Alternatives(Generic[S, P, R]): + reference: Implementation[S, P, R] + _default: Implementation[S, P, R] | None + _debug_default: str | None + _invoked: bool + _debug_invoked_site: str | None + _enumerated: bool + _callable: Callable[..., R] | _TypedDescriptor[P, R] | None + _debug_callable_used: str | None + _implementations: list[Implementation[S, P, R]] + _implementations_used: bool + _debug_implementations_used: str | None + + def __init__( + self, + implementation: Callable[..., R] | _TypedDescriptor[P, R], + *, + default: bool = False, + ) -> None: ... + @overload + def add( + self: Alternatives[_NoReceiver, P, R], + implementation: Callable[P, R] + | Callable[Concatenate[type[_Owner], P], R] + | classmethod[_Owner, P, R] + | staticmethod[P, R], + *, + default: bool = False, + ) -> Implementation[_NoReceiver, P, R]: ... + @overload + def add( + self, + implementation: Callable[Concatenate[S, P], R] | Implementation[S, P, R], + *, + default: bool = False, + ) -> Implementation[S, P, R]: ... + @overload + def add( + self, *, default: bool = False + ) -> Callable[[Callable[..., R]], Implementation[S, P, R]]: ... + @property + def callable(self) -> Callable[..., R] | _TypedDescriptor[P, R]: ... + @property + def implementations(self) -> list[Implementation[S, P, R]]: ... + @overload + def __call__(self: Alternatives[_NoReceiver, [], R]) -> R: ... + @overload + def __call__(self: Alternatives[_NoReceiver, [_A1], R], arg1: _A1, /) -> R: ... + @overload + def __call__( + self: Alternatives[_NoReceiver, [_A1, _A2], R], arg1: _A1, arg2: _A2, / + ) -> R: ... + @overload + def __call__( + self: Alternatives[_NoReceiver, [_A1, _A2, _A3], R], + arg1: _A1, + arg2: _A2, + arg3: _A3, + /, + ) -> R: ... + @overload + def __call__( + self: Alternatives[_NoReceiver, [_A1, _A2, _A3, _A4], R], + arg1: _A1, + arg2: _A2, + arg3: _A3, + arg4: _A4, + /, + ) -> R: ... + @overload + def __call__( + self: Alternatives[_NoReceiver, [_A1, _A2, _A3, _A4, _A5], R], + arg1: _A1, + arg2: _A2, + arg3: _A3, + arg4: _A4, + arg5: _A5, + /, + ) -> R: ... + @overload + def __call__(self: Alternatives[S, [], R], receiver: S, /) -> R: ... + @overload + def __call__(self: Alternatives[S, [_A1], R], receiver: S, arg1: _A1, /) -> R: ... + @overload + def __call__( + self: Alternatives[S, [_A1, _A2], R], receiver: S, arg1: _A1, arg2: _A2, / + ) -> R: ... + @overload + def __call__( + self: Alternatives[S, [_A1, _A2, _A3], R], + receiver: S, + arg1: _A1, + arg2: _A2, + arg3: _A3, + /, + ) -> R: ... + @overload + def __call__( + self: Alternatives[S, [_A1, _A2, _A3, _A4], R], + receiver: S, + arg1: _A1, + arg2: _A2, + arg3: _A3, + arg4: _A4, + /, + ) -> R: ... + @overload + def __call__( + self: Alternatives[S, [_A1, _A2, _A3, _A4, _A5], R], + receiver: S, + arg1: _A1, + arg2: _A2, + arg3: _A3, + arg4: _A4, + arg5: _A5, + /, + ) -> R: ... + @overload + def __call__( + self: Alternatives[_NoReceiver, P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: ... + @overload + def __call__(self, receiver: S, *args: P.args, **kwargs: P.kwargs) -> R: ... + @overload + def __get__( + self, instance: None, owner: type[object] | None = None + ) -> Alternatives[S, P, R]: ... + @overload + def __get__( + self: Alternatives[_NoReceiver, P, R], + instance: object, + owner: type[object] | None = None, + ) -> _BoundAlternatives[_NoReceiver, P, R]: ... + @overload + def __get__( + self, instance: S, owner: type[object] | None = None + ) -> _BoundAlternatives[S, P, R]: ... + @overload + def __get__( + self, instance: object, owner: type[object] | None = None + ) -> Callable[Concatenate[S, P], R]: ... + def _binds_on_class_access(self) -> bool: ... + def measure( + self, + /, + operator: Callable[[R], M], + *args: object, + **kwargs: object, + ) -> dict[Implementation[S, P, R], M]: ... + @overload + def pytest_parametrize( + self, + *, + only_default: bool = False, + ) -> Callable[[F], F]: ... + @overload + def pytest_parametrize( + self, + test: F, + *, + only_default: bool = False, + ) -> F: ... + @overload + def pytest_parametrize_pairs( + self, + *, + n_cache: int | None = 0, + double_reference: bool = False, + only_default: bool = False, + ) -> Callable[[F], F]: ... + @overload + def pytest_parametrize_pairs( + self, + test: F, + *, + n_cache: int | None = 0, + double_reference: bool = False, + only_default: bool = False, + ) -> F: ... + +class Implementation(Generic[S, P, R]): + __match_args__: tuple[str, str, str] + alternatives: Alternatives[S, P, R] + implementation: Callable[..., R] | _TypedDescriptor[P, R] + label: str | None + + def __init__( + self, + alternatives: Alternatives[S, P, R], + implementation: Callable[..., R] | _TypedDescriptor[P, R], + label: str | None = None, + ) -> None: ... + def __post_init__(self) -> None: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... + @overload + def __call__(self: Implementation[_NoReceiver, [], R]) -> R: ... + @overload + def __call__(self: Implementation[_NoReceiver, [_A1], R], arg1: _A1, /) -> R: ... + @overload + def __call__( + self: Implementation[_NoReceiver, [_A1, _A2], R], arg1: _A1, arg2: _A2, / + ) -> R: ... + @overload + def __call__( + self: Implementation[_NoReceiver, [_A1, _A2, _A3], R], + arg1: _A1, + arg2: _A2, + arg3: _A3, + /, + ) -> R: ... + @overload + def __call__( + self: Implementation[_NoReceiver, [_A1, _A2, _A3, _A4], R], + arg1: _A1, + arg2: _A2, + arg3: _A3, + arg4: _A4, + /, + ) -> R: ... + @overload + def __call__( + self: Implementation[_NoReceiver, [_A1, _A2, _A3, _A4, _A5], R], + arg1: _A1, + arg2: _A2, + arg3: _A3, + arg4: _A4, + arg5: _A5, + /, + ) -> R: ... + @overload + def __call__(self: Implementation[S, [], R], receiver: S, /) -> R: ... + @overload + def __call__(self: Implementation[S, [_A1], R], receiver: S, arg1: _A1, /) -> R: ... + @overload + def __call__( + self: Implementation[S, [_A1, _A2], R], receiver: S, arg1: _A1, arg2: _A2, / + ) -> R: ... + @overload + def __call__( + self: Implementation[S, [_A1, _A2, _A3], R], + receiver: S, + arg1: _A1, + arg2: _A2, + arg3: _A3, + /, + ) -> R: ... + @overload + def __call__( + self: Implementation[S, [_A1, _A2, _A3, _A4], R], + receiver: S, + arg1: _A1, + arg2: _A2, + arg3: _A3, + arg4: _A4, + /, + ) -> R: ... + @overload + def __call__( + self: Implementation[S, [_A1, _A2, _A3, _A4, _A5], R], + receiver: S, + arg1: _A1, + arg2: _A2, + arg3: _A3, + arg4: _A4, + arg5: _A5, + /, + ) -> R: ... + @overload + def __call__( + self: Implementation[_NoReceiver, P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: ... + @overload + def __call__(self, receiver: S, *args: P.args, **kwargs: P.kwargs) -> R: ... + @overload + def __get__( + self, instance: None, owner: type[object] | None = None + ) -> Implementation[S, P, R]: ... + @overload + def __get__( + self: Implementation[_NoReceiver, P, R], + instance: object, + owner: type[object] | None = None, + ) -> _BoundImplementation[_NoReceiver, P, R]: ... + @overload + def __get__( + self, instance: S, owner: type[object] | None = None + ) -> _BoundImplementation[S, P, R]: ... + @overload + def __get__( + self, instance: object, owner: type[object] | None = None + ) -> Callable[Concatenate[S, P], R]: ... + @overload + def add( + self: Implementation[_NoReceiver, P, R], + implementation: Callable[P, R] + | Callable[Concatenate[type[_Owner], P], R] + | classmethod[_Owner, P, R] + | staticmethod[P, R], + *, + default: bool = False, + ) -> Implementation[_NoReceiver, P, R]: ... + @overload + def add( + self, + implementation: Callable[Concatenate[S, P], R] | Implementation[S, P, R], + *, + default: bool = False, + ) -> Implementation[S, P, R]: ... + @overload + def add( + self, *, default: bool = False + ) -> Callable[[Callable[..., R]], Implementation[S, P, R]]: ... + +@overload +def reference( + implementation: classmethod[_Owner, P, R], + *, + default: bool = False, +) -> Alternatives[_NoReceiver, P, R]: ... +@overload +def reference( + implementation: staticmethod[P, R], + *, + default: bool = False, +) -> Alternatives[_NoReceiver, P, R]: ... +@overload +def reference( + implementation: Callable[Concatenate[type[_Owner], P], R], + *, + default: bool = False, +) -> Alternatives[_NoReceiver, P, R]: ... +@overload +def reference( + implementation: Callable[Concatenate[S, P], R], *, default: bool = False +) -> Alternatives[S, P, R]: ... +@overload +def reference( + implementation: Callable[P, R], *, default: bool = False +) -> Alternatives[_NoReceiver, P, R]: ... +@overload +def reference(*, default: bool = False) -> _ReferenceDecorator: ... +def _get_caller_path() -> str | None: ... diff --git a/pyproject.toml b/pyproject.toml index 02f2d74..546fb42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,10 @@ dev = [ [tool.uv] dev-dependencies = [ "alternative[dev]", + "build>=1.2.2", + "hatchling>=1.27.0", "mypy>=2.0.0", + "pyright>=1.1.407", "pyrefly>=0.64.1", "ruff>=0.15.12", "typing-extensions>=4.15.0", @@ -60,6 +63,27 @@ alternative = { workspace = true } requires = ["hatchling"] build-backend = "hatchling.build" +[tool.hatch.build.targets.sdist] +include = [ + "/alternative.py", + "/alternative.pyi", + "/docs", + "/examples", + "/LICENSE", + "/pyproject.toml", + "/README.md", + "/scripts", + "/test_*.py", + "/typing_tests", + "/uv.lock", +] + +[tool.hatch.build.targets.wheel] +only-include = [ + "alternative.py", + "alternative.pyi", +] + [tool.pytest.ini_options] # do 5 rounds of 0.01 benchmarks, as the benchmarks are examples or very fast addopts = "--cov=alternative --cov-report=html --benchmark-max-time=0.01" diff --git a/test_packaging.py b/test_packaging.py new file mode 100644 index 0000000..bee32b4 --- /dev/null +++ b/test_packaging.py @@ -0,0 +1,29 @@ +import subprocess +import sys +import zipfile +from pathlib import Path + + +def test_wheel_includes_stub(tmp_path: Path) -> None: + """The built wheel ships the top-level stub used by type checkers.""" + subprocess.run( + [ + sys.executable, + "-m", + "build", + "--wheel", + "--no-isolation", + "--outdir", + str(tmp_path), + ], + check=True, + cwd=Path(__file__).resolve().parent, + ) + + wheels = sorted(tmp_path.glob("alternative-*.whl")) + assert len(wheels) == 1 + with zipfile.ZipFile(wheels[0]) as wheel: + names = set(wheel.namelist()) + + assert "alternative.py" in names + assert "alternative.pyi" in names diff --git a/typing_tests/callable_probes.py b/typing_tests/callable_probes.py new file mode 100644 index 0000000..6656486 --- /dev/null +++ b/typing_tests/callable_probes.py @@ -0,0 +1,51 @@ +"""Type-level probes for assigning alternative wrappers to Callable types.""" + +from collections.abc import Callable + +import alternative + + +@alternative.reference +def normal_function(count: int, label: str) -> str: + """Return a labelled value repeated a requested number of times.""" + return label * count + + +normal_function_callable: Callable[[int, str], str] = normal_function + + +class NormalMethods: + """Container for instance method callable assignment probes.""" + + @alternative.reference + def method(self, count: int, label: str) -> str: + """Return a labelled value repeated a requested number of times.""" + return label * count + + +normal_methods = NormalMethods() +bound_method_callable: Callable[[int, str], str] = normal_methods.method +unbound_method_callable: Callable[[NormalMethods, int, str], str] = NormalMethods.method + + +class DescriptorMethods: + """Container for classmethod and staticmethod callable assignment probes.""" + + @alternative.reference + @classmethod + def build(cls, count: int, label: str) -> str: + """Return a labelled value repeated a requested number of times.""" + return label * count + + @alternative.reference + @staticmethod + def parse(count: int, label: str) -> str: + """Return a labelled value repeated a requested number of times.""" + return label * count + + +classmethod_callable: Callable[[int, str], str] = DescriptorMethods.build +bound_classmethod_callable: Callable[[int, str], str] = DescriptorMethods().build + +staticmethod_callable: Callable[[int, str], str] = DescriptorMethods.parse +bound_staticmethod_callable: Callable[[int, str], str] = DescriptorMethods().parse diff --git a/typing_tests/type_probes.py b/typing_tests/type_probes.py index df4e50e..0ab5c30 100644 --- a/typing_tests/type_probes.py +++ b/typing_tests/type_probes.py @@ -4,8 +4,6 @@ type-checking stages, which run mypy and pyrefly over the whole repository. """ -from collections.abc import Callable - from typing_extensions import assert_type import alternative @@ -18,7 +16,6 @@ def normal_function(count: int, label: str) -> str: assert_type(normal_function(2, "a"), str) -normal_function_callable: Callable[[int, str], str] = normal_function class NormalMethods: @@ -38,8 +35,6 @@ def method_extra(self, count: int, label: str) -> str: normal_methods = NormalMethods() assert_type(normal_methods.method(2, "a"), str) assert_type(normal_methods.method_extra(2, "a"), str) -bound_method_callable: Callable[[int, str], str] = normal_methods.method -unbound_method_callable: Callable[[NormalMethods, int, str], str] = NormalMethods.method class DescriptorMethods: @@ -73,11 +68,7 @@ def parse_extra(count: int, label: str) -> str: assert_type(DescriptorMethods.build(2, "a"), str) assert_type(DescriptorMethods().build(2, "a"), str) assert_type(DescriptorMethods.build_extra(2, "a"), str) -classmethod_callable: Callable[[int, str], str] = DescriptorMethods.build -bound_classmethod_callable: Callable[[int, str], str] = DescriptorMethods().build assert_type(DescriptorMethods.parse(2, "a"), str) assert_type(DescriptorMethods().parse(2, "a"), str) assert_type(DescriptorMethods.parse_extra(2, "a"), str) -staticmethod_callable: Callable[[int, str], str] = DescriptorMethods.parse -bound_staticmethod_callable: Callable[[int, str], str] = DescriptorMethods().parse diff --git a/uv.lock b/uv.lock index 1c78cc9..9fbae1e 100644 --- a/uv.lock +++ b/uv.lock @@ -28,8 +28,11 @@ dev = [ [package.dev-dependencies] dev = [ + { name = "build" }, + { name = "hatchling" }, { name = "mypy" }, { name = "pyrefly" }, + { name = "pyright" }, { name = "pytest" }, { name = "pytest-benchmark" }, { name = "pytest-cov" }, @@ -51,8 +54,11 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "alternative", extras = ["dev"], editable = "." }, + { name = "build", specifier = ">=1.2.2" }, + { name = "hatchling", specifier = ">=1.27.0" }, { name = "mypy", specifier = ">=2.0.0" }, { name = "pyrefly", specifier = ">=0.64.1" }, + { name = "pyright", specifier = ">=1.1.407" }, { name = "ruff", specifier = ">=0.15.12" }, { name = "typing-extensions", specifier = ">=4.15.0" }, ] @@ -108,6 +114,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/77/f5/21d2de20e8b8b0408f0681956ca2c69f1320a3848ac50e6e7f39c6159675/babel-2.18.0-py3-none-any.whl", hash = "sha256:e2b422b277c2b9a9630c1d7903c2a00d0830c409c59ac8cae9081c92f1aeba35", size = 10196845 }, ] +[[package]] +name = "build" +version = "1.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "os_name == 'nt'" }, + { name = "importlib-metadata", marker = "python_full_version < '3.10.2'" }, + { name = "packaging" }, + { name = "pyproject-hooks" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/e0/df5e171f685f82f37b12e1f208064e24244911079d7b767447d1af7e0d70/build-1.5.0.tar.gz", hash = "sha256:302c22c3ba2a0fd5f3911918651341ebb3896176cbdec15bd421f80b1afc7647", size = 89796 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/fe/6bea5c9162869c5beba5d9c8abbed835ec85bf1ec1fba05a3822325c45f3/build-1.5.0-py3-none-any.whl", hash = "sha256:13f3eecb844759ab66efec90ca17639bbf14dc06cb2fdf37a9010322d9c50a6f", size = 26018 }, +] + [[package]] name = "certifi" version = "2026.4.22" @@ -321,6 +343,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8a/0e/97c33bf5009bdbac74fd2beace167cab3f978feb69cc36f1ef79360d6c4e/exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598", size = 16740 }, ] +[[package]] +name = "hatchling" +version = "1.29.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, + { name = "pathspec" }, + { name = "pluggy" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "trove-classifiers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cf/9c/b4cfe330cd4f49cff17fd771154730555fa4123beb7f292cf0098b4e6c20/hatchling-1.29.0.tar.gz", hash = "sha256:793c31816d952cee405b83488ce001c719f325d9cda69f1fc4cd750527640ea6", size = 55656 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d3/8a/44032265776062a89171285ede55a0bdaadc8ac00f27f0512a71a9e3e1c8/hatchling-1.29.0-py3-none-any.whl", hash = "sha256:50af9343281f34785fab12da82e445ed987a6efb34fd8c2fc0f6e6630dbcc1b0", size = 76356 }, +] + [[package]] name = "idna" version = "3.13" @@ -339,6 +377,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5f/53/fb7122b71361a0d121b669dcf3d31244ef75badbbb724af388948de543e2/imagesize-2.0.0-py2.py3-none-any.whl", hash = "sha256:5667c5bbb57ab3f1fa4bc366f4fbc971db3d5ed011fd2715fd8001f782718d96", size = 9441 }, ] +[[package]] +name = "importlib-metadata" +version = "9.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "zipp", marker = "python_full_version < '3.15'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a9/01/15bb152d77b21318514a96f43af312635eb2500c96b55398d020c93d86ea/importlib_metadata-9.0.0.tar.gz", hash = "sha256:a4f57ab599e6a2e3016d7595cfd72eb4661a5106e787a95bcc90c7105b831efc", size = 56405 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/38/3d/2d244233ac4f76e38533cfcb2991c9eb4c7bf688ae0a036d30725b8faafe/importlib_metadata-9.0.0-py3-none-any.whl", hash = "sha256:2d21d1cc5a017bd0559e36150c21c830ab1dc304dedd1b7ea85d20f45ef3edd7", size = 27789 }, +] + [[package]] name = "iniconfig" version = "2.1.0" @@ -598,6 +648,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963 }, ] +[[package]] +name = "nodeenv" +version = "1.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/24/bf/d1bda4f6168e0b2e9e5958945e01910052158313224ada5ce1fb2e1113b8/nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb", size = 55611 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438 }, +] + [[package]] name = "packaging" version = "25.0" @@ -643,6 +702,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f4/7e/a72dd26f3b0f4f2bf1dd8923c85f7ceb43172af56d63c7383eb62b332364/pygments-2.20.0-py3-none-any.whl", hash = "sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176", size = 1231151 }, ] +[[package]] +name = "pyproject-hooks" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/82/28175b2414effca1cdac8dc99f76d660e7a4fb0ceefa4b4ab8f5f6742925/pyproject_hooks-1.2.0.tar.gz", hash = "sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8", size = 19228 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bd/24/12818598c362d7f300f18e74db45963dbcb85150324092410c8b49405e42/pyproject_hooks-1.2.0-py3-none-any.whl", hash = "sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913", size = 10216 }, +] + [[package]] name = "pyrefly" version = "0.64.1" @@ -660,6 +728,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c2/7c/449407653fe95e3f3a65dd8a54d8729ac0451247489d79d4e07808d73917/pyrefly-0.64.1-py3-none-win_arm64.whl", hash = "sha256:8f83a74c1463842d486d6578a000feccf47cd54d6d7d6628ffe73b1055ca9dce", size = 12528438 }, ] +[[package]] +name = "pyright" +version = "1.1.409" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nodeenv" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/51/4e/3aa27f74211522dba7e9cbc3e74de779c6d4b654c54e50a4840623be8014/pyright-1.1.409.tar.gz", hash = "sha256:986ee05beca9e077c165758ad123667c679e050059a2546aa02473930394bc93", size = 4430434 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/16/6b/330d8ebae582b30c2959a1ef4c3bc344ebde48c2ff0c3f113c4710735e11/pyright-1.1.409-py3-none-any.whl", hash = "sha256:aa3ea228cab90c845c7a60d28db7a844c04315356392aa09fafcee98c8c22fb3", size = 6438161 }, +] + [[package]] name = "pytest" version = "8.3.5" @@ -914,6 +995,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7b/61/cceae43728b7de99d9b847560c262873a1f6c98202171fd5ed62640b494b/tomli-2.4.1-py3-none-any.whl", hash = "sha256:0d85819802132122da43cb86656f8d1f8c6587d54ae7dcaf30e90533028b49fe", size = 14583 }, ] +[[package]] +name = "trove-classifiers" +version = "2026.5.7.17" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/35/68/175e7c07c5be13200387d5c0995b0da1e198e360047c08eb17d1002fcd92/trove_classifiers-2026.5.7.17.tar.gz", hash = "sha256:a04a48f8f0a787cb996514d3969ac7608aa3c60cb15d073c1e02801e60533e80", size = 17041 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/e3/d81b065a2d866a33a541ac63a2a4cc5737e03ce2379ac3191c98bb8867e3/trove_classifiers-2026.5.7.17-py3-none-any.whl", hash = "sha256:5ec0800de5e2ddbd7c663cb4c0c15328f132dc168813897c18866c5c7b93db33", size = 14201 }, +] + [[package]] name = "typing-extensions" version = "4.15.0" @@ -931,3 +1021,12 @@ sdist = { url = "https://files.pythonhosted.org/packages/53/0c/06f8b233b8fd13b9e wheels = [ { url = "https://files.pythonhosted.org/packages/7f/3e/5db95bcf282c52709639744ca2a8b149baccf648e39c8cc87553df9eae0c/urllib3-2.7.0-py3-none-any.whl", hash = "sha256:9fb4c81ebbb1ce9531cce37674bbc6f1360472bc18ca9a553ede278ef7276897", size = 131087 }, ] + +[[package]] +name = "zipp" +version = "3.23.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/30/21/093488dfc7cc8964ded15ab726fad40f25fd3d788fd741cc1c5a17d78ee8/zipp-3.23.1.tar.gz", hash = "sha256:32120e378d32cd9714ad503c1d024619063ec28aad2248dc6672ad13edfa5110", size = 25965 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/8a/0861bec20485572fbddf3dfba2910e38fe249796cb73ecdeb74e07eeb8d3/zipp-3.23.1-py3-none-any.whl", hash = "sha256:0b3596c50a5c700c9cb40ba8d86d9f2cc4807e9bedb06bcdf7fac85633e444dc", size = 10378 }, +] From 969003f9faa07d353f32199862899f19bd88ff7f Mon Sep 17 00:00:00 2001 From: Oliver Bristow Date: Sun, 10 May 2026 13:19:48 +0100 Subject: [PATCH 6/8] test docs --- conftest.py | 9 +++++++++ docs/quickstart.rst | 14 +++++++++++--- docs/workflow.rst | 18 ++++++++++++++++++ pyproject.toml | 2 ++ uv.lock | 12 ++++++++++++ 5 files changed, 52 insertions(+), 3 deletions(-) create mode 100644 conftest.py diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..1adc1b1 --- /dev/null +++ b/conftest.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from sybil import Sybil +from sybil.parsers.rest import PythonCodeBlockParser + +pytest_collect_file = Sybil( + parsers=[PythonCodeBlockParser()], + patterns=["*.rst"], +).pytest() diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 6623d62..e52c4f9 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -42,12 +42,20 @@ runtime implementation: .. code-block:: python - @normalise_name.add(default=True) - def normalise_name_fast(name: str) -> str: + import alternative + + + @alternative.reference + def display_name(name: str) -> str: + return " ".join(part.capitalize() for part in name.split()) + + + @display_name.add(default=True) + def display_name_fast(name: str) -> str: return name.title() - assert normalise_name("grace hopper") == "Grace Hopper" + assert display_name("grace hopper") == "Grace Hopper" Only one explicit default can be registered. This catches accidental import order changes where two modules both try to choose the active implementation. diff --git a/docs/workflow.rst b/docs/workflow.rst index 71bd83b..4ce2462 100644 --- a/docs/workflow.rst +++ b/docs/workflow.rst @@ -90,6 +90,24 @@ to distinguish. Copying Implementations ----------------------- +.. invisible-code-block: python + + import alternative + + @alternative.reference + def source_alternatives() -> int: + return 1 + + @source_alternatives.add + def source_candidate() -> int: + return 2 + + @alternative.reference + def target() -> int: + return 10 + + source_implementation = source_alternatives.implementations[1] + An :class:`alternative.Alternatives` object can be added to another alternatives set, which copies its underlying reference implementation: diff --git a/pyproject.toml b/pyproject.toml index 546fb42..ac6248d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dev = [ "pytest>=8.3.5", "pytest-benchmark>=5.1.0", "pytest-cov>=6.1.1", + "sybil>=9.3.0,<10", ] [tool.uv] @@ -67,6 +68,7 @@ build-backend = "hatchling.build" include = [ "/alternative.py", "/alternative.pyi", + "/conftest.py", "/docs", "/examples", "/LICENSE", diff --git a/uv.lock b/uv.lock index 9fbae1e..807d43b 100644 --- a/uv.lock +++ b/uv.lock @@ -24,6 +24,7 @@ dev = [ { name = "pytest" }, { name = "pytest-benchmark" }, { name = "pytest-cov" }, + { name = "sybil" }, ] [package.dev-dependencies] @@ -37,6 +38,7 @@ dev = [ { name = "pytest-benchmark" }, { name = "pytest-cov" }, { name = "ruff" }, + { name = "sybil" }, { name = "typing-extensions" }, ] docs = [ @@ -49,6 +51,7 @@ requires-dist = [ { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.3.5" }, { name = "pytest-benchmark", marker = "extra == 'dev'", specifier = ">=5.1.0" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=6.1.1" }, + { name = "sybil", marker = "extra == 'dev'", specifier = ">=9.3.0,<10" }, ] [package.metadata.requires-dev] @@ -941,6 +944,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/52/a7/d2782e4e3f77c8450f727ba74a8f12756d5ba823d81b941f1b04da9d033a/sphinxcontrib_serializinghtml-2.0.0-py3-none-any.whl", hash = "sha256:6e2cb0eef194e10c27ec0023bfeb25badbbb5868244cf5bc5bdc04e4464bf331", size = 92072 }, ] +[[package]] +name = "sybil" +version = "9.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c6/46/bae21847b8d761ddd6ede1811d32818342dbd482c32a2a5805c28d9b2f18/sybil-9.3.0.tar.gz", hash = "sha256:847d1d17b8a857c4bb3f8471b4a57b8affa939a60fbf507e70aa72ad79097c05", size = 89078 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/08/cd3cf2a82570748cfb3142e795044197deff81ad3b70a0b9a9c22331e70a/sybil-9.3.0-py3-none-any.whl", hash = "sha256:0b108b980ab9fac774953042b07fcb5858aa19a38404d0cb42c30c93423ac0c1", size = 39286 }, +] + [[package]] name = "tomli" version = "2.4.1" From af0c3aa862b54af8c8ebcaa0e64db13c2d456900 Mon Sep 17 00:00:00 2001 From: Oliver Bristow Date: Sun, 10 May 2026 14:08:14 +0100 Subject: [PATCH 7/8] Update README.md and lean on pydocs instead of examples/ --- README.md | 69 +++++++++++++++++++++++++++++++++--- examples/test_benchmark.py | 19 ---------- examples/test_equivalence.py | 26 -------------- examples/test_measure.py | 43 ---------------------- pyproject.toml | 3 +- test_alternative.py | 48 +++++++++++++++++++++++++ test_pytest_util.py | 17 +++++++++ 7 files changed, 130 insertions(+), 95 deletions(-) delete mode 100644 examples/test_benchmark.py delete mode 100644 examples/test_equivalence.py delete mode 100644 examples/test_measure.py diff --git a/README.md b/README.md index 57aa6ef..d9ba8f9 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,8 @@ When optimizing a hot path, it’s common to accumulate: `alternative` keeps that workflow tidy by making implementation registration and selection first-class. +The same model works for module functions, instance methods, class methods, and static methods. Public typing is shipped in [`alternative.pyi`](alternative.pyi), so type checkers and IDEs can see the original call signatures instead of losing them behind the decorator objects. + ## Quick example ```python @@ -45,23 +47,72 @@ assert constant_number() == 2 assert unused_alternative_constant_number() == 3 ``` +See the [quickstart](https://alternative.readthedocs.io/en/latest/quickstart.html) for registration patterns, defaults, and method examples. + +## Methods and descriptors + +Decorate instance methods directly. For `@classmethod` and `@staticmethod`, put `@alternative.reference` and `.add(...)` outside the built-in descriptor decorator: + +```python +import alternative + + +class Parser: + def __init__(self, value: str = ""): + self.value = value + + @alternative.reference + def parse(self, value: str) -> int: + return int(value.strip()) + + @parse.add(default=True) + def parse_fast(self, value: str) -> int: + return int(value) + + @alternative.reference + @classmethod + def from_text(cls, value: str) -> "Parser": + return cls(value.strip()) + + @from_text.add(default=True) + @classmethod + def from_text_fast(cls, value: str) -> "Parser": + return cls(value) + + @alternative.reference + @staticmethod + def is_valid(value: str) -> bool: + return value.strip().isdigit() + + @is_valid.add(default=True) + @staticmethod + def is_valid_fast(value: str) -> bool: + return value.isdigit() +``` + +Calling through an instance or class follows normal Python binding rules, and direct implementation calls bind the same way. The full descriptor examples are in [Use Methods](https://alternative.readthedocs.io/en/latest/quickstart.html#use-methods) and [Testing Methods](https://alternative.readthedocs.io/en/latest/pytest.html#testing-methods). + ## Pytest features -The examples directory includes practical pytest patterns that make this library shine. +The pytest helpers are documented in the [pytest integration guide](https://alternative.readthedocs.io/en/latest/pytest.html). ### Pairwise equivalence checks Use `pytest_parametrize_pairs(...)` to compare the reference against each candidate implementation. -- Basic pairwise checks: [`examples/test_measure.py`](examples/test_measure.py) -- More configurable pairwise checks: [`examples/test_equivalence.py`](examples/test_equivalence.py) +- [Equivalence Tests](https://alternative.readthedocs.io/en/latest/pytest.html#equivalence-tests) +- [Reference Caching](https://alternative.readthedocs.io/en/latest/pytest.html#reference-caching) ### Single-implementation parametrization Use `pytest_parametrize(...)` to run one test body across all implementations. -- Great for benchmark workflows with [`pytest-benchmark`](https://pypi.org/project/pytest-benchmark/): [`examples/test_benchmark.py`](examples/test_benchmark.py) -- Useful for validating that every implementation passes one shared test suite +- [Only the Default Implementation](https://alternative.readthedocs.io/en/latest/pytest.html#only-the-default-implementation) +- [Benchmark All Implementations](https://alternative.readthedocs.io/en/latest/pytest.html#benchmark-all-implementations) with [`pytest-benchmark`](https://pypi.org/project/pytest-benchmark/) + +## Runtime tools + +`Alternatives.measure(...)` runs every implementation with the same arguments and measures the results with a callable you provide. See [Measure Implementations](https://alternative.readthedocs.io/en/latest/workflow.html#measure-implementations). ## Safety guarantees @@ -75,3 +126,11 @@ The library tries to avoid unpleasant surprises caused by import order or accide Set `ALTERNATIVE_DEBUG=1` to record where critical state changes happened (like selecting defaults or inspecting implementations). These locations are surfaced in error messages to make stateful issues easier to track down. When debug mode is enabled, each `Implementation` also captures a label with its registration call-site. This label appears in `repr(...)` and selected debug errors, making it easier to disambiguate implementation instances. + +## Typing and IDEs + +`alternative` ships a top-level stub file, [`alternative.pyi`](alternative.pyi), for the public typing surface. It includes overloads for descriptor binding, transparent method/classmethod/staticmethod decoration, and the pytest helpers, while [`alternative.py`](alternative.py) stays focused on runtime behavior. + +The typing probes are checked with mypy, pyright, pyrefly, and a headless PyCharm inspection script: [`scripts/pycharm-type-probes.sh`](scripts/pycharm-type-probes.sh). The PyCharm probe covers type assertions, unresolved references, and type checker warnings in [`typing_tests/type_probes.py`](typing_tests/type_probes.py). + +Known PyCharm caveat: JetBrains `PyNestedDecoratorsInspection` currently reports a false-positive for correctly typed decorators stacked outside `@classmethod` or `@staticmethod`. Runtime behavior and type resolution are correct, and the project does not require `# noinspection PyTypeChecker` call-site suppressions for these examples. diff --git a/examples/test_benchmark.py b/examples/test_benchmark.py deleted file mode 100644 index 8aefbc9..0000000 --- a/examples/test_benchmark.py +++ /dev/null @@ -1,19 +0,0 @@ -import alternative - - -@alternative.reference -def reference_implementation(): - """Reference implementation.""" - return 1 - - -@reference_implementation.add -def alternative_implementation1(): - """Another implementation.""" - return int(True) - - -@reference_implementation.pytest_parametrize(only_default=False) -def test_f(benchmark, implementation): - """Benchmark all implementations using the pytest-benchmark `benchmark` fixture.""" - assert benchmark(implementation) == 1 diff --git a/examples/test_equivalence.py b/examples/test_equivalence.py deleted file mode 100644 index 63a8f16..0000000 --- a/examples/test_equivalence.py +++ /dev/null @@ -1,26 +0,0 @@ -import alternative -import cmath - - -@alternative.reference -def reference_implementation(): - """Reference implementation.""" - return 1 - - -@reference_implementation.add -def alternative_implementation1(): - """Another implementation.""" - return int(True) - - -@reference_implementation.add -def alternative_implementation2(): - """Yet another implementation.""" - return abs(cmath.exp(1j * cmath.pi)) - - -@reference_implementation.pytest_parametrize_pairs(n_cache=None, only_default=False) -def test_f(reference, implementation): - """Compare the output of the reference (with caching) and each alternative implementation.""" - assert reference() == implementation() diff --git a/examples/test_measure.py b/examples/test_measure.py deleted file mode 100644 index 89a0936..0000000 --- a/examples/test_measure.py +++ /dev/null @@ -1,43 +0,0 @@ -import alternative - - -@alternative.reference -def make_four(): - """Reference implementation.""" - return "1 + 1 + 1 + 1" - - -@make_four.add -def make_four_factor(): - """Another implementation.""" - return "2 * 2" - - -@make_four.add -def make_four_literal(): - """Another implementation.""" - return "4" - - -@make_four.pytest_parametrize_pairs() -def test_f(reference, implementation): - """Compare the output of the reference (with caching) and each alternative implementation.""" - assert eval(reference()) == eval(implementation()) - - -def test_measure(): - """The measure is applied to all the results""" - measurements = make_four.measure(len) - assert [(i.implementation.__name__, m) for i, m in measurements.items()] == list( - {"make_four_literal": 1, "make_four_factor": 5, "make_four": 13}.items() - ) - - -def test_measure_unsortable(): - """The measure is applied to all the results""" - # convert the length to a complex number to make it unsortable - measurements = make_four.measure(lambda code: len(code) + 0j) - # the measurements are in the order of the implementations - assert [(i.implementation.__name__, m) for i, m in measurements.items()] == list( - {"make_four": 13, "make_four_factor": 5, "make_four_literal": 1}.items() - ) diff --git a/pyproject.toml b/pyproject.toml index ac6248d..591efa3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,6 @@ include = [ "/alternative.pyi", "/conftest.py", "/docs", - "/examples", "/LICENSE", "/pyproject.toml", "/README.md", @@ -87,7 +86,7 @@ only-include = [ ] [tool.pytest.ini_options] -# do 5 rounds of 0.01 benchmarks, as the benchmarks are examples or very fast +# keep benchmark-backed tests bounded when present addopts = "--cov=alternative --cov-report=html --benchmark-max-time=0.01" diff --git a/test_alternative.py b/test_alternative.py index 5f15432..3de535e 100644 --- a/test_alternative.py +++ b/test_alternative.py @@ -251,6 +251,54 @@ def f2(): assert f1.add(alt1) is not alt1 +def test_measure_sorts_sortable_measurements() -> None: + """Measurements are sorted by the measured value when the values are sortable.""" + + @alternative.reference + def make_four() -> str: + return "1 + 1 + 1 + 1" + + @make_four.add + def make_four_factor() -> str: + return "2 * 2" + + @make_four.add + def make_four_literal() -> str: + return "4" + + measurements = make_four.measure(len) + + assert list(measurements.items()) == [ + (make_four_literal, 1), + (make_four_factor, 5), + (make_four.reference, 13), + ] + + +def test_measure_preserves_registration_order_for_unsortable_measurements() -> None: + """Measurements keep registration order when the measured values cannot be sorted.""" + + @alternative.reference + def make_four() -> str: + return "1 + 1 + 1 + 1" + + @make_four.add + def make_four_factor() -> str: + return "2 * 2" + + @make_four.add + def make_four_literal() -> str: + return "4" + + measurements = make_four.measure(lambda code: len(code) + 0j) + + assert list(measurements.items()) == [ + (make_four.reference, 13 + 0j), + (make_four_factor, 5 + 0j), + (make_four_literal, 1 + 0j), + ] + + def test_cross_owner_add_error(): """Adding a cross-owner implementation raises a dedicated explicit error.""" diff --git a/test_pytest_util.py b/test_pytest_util.py index a9c310c..12cb2f0 100644 --- a/test_pytest_util.py +++ b/test_pytest_util.py @@ -116,6 +116,23 @@ def parametrized(implementation: Callable[[], int]) -> None: assert selected == [reference_impl.reference.implementation] +def test_pytest_parametrize_invokes_wrapped_test() -> None: + """Implementation parametrization delegates to the original test body.""" + + @alternative.reference + def reference_impl(value: int) -> int: + return value + + def parametrized(implementation: Callable[[int], int], value: int) -> int: + """Placeholder test used to inspect direct decorated invocation.""" + return implementation(value) + + decorated = reference_impl.pytest_parametrize(parametrized) + + assert inspect.signature(decorated) == inspect.signature(parametrized) + assert decorated(reference_impl, 3) == 3 + + @pytest.mark.parametrize("only_default", [False, True]) @pytest.mark.parametrize("double_reference", [False, True]) def test_pytest_parametrize_pairs_signature_and_parameters( From c2b3e07a97cc8aeefdfbc36b275cbea1a27c5f4e Mon Sep 17 00:00:00 2001 From: Oliver Bristow Date: Sun, 10 May 2026 14:20:08 +0100 Subject: [PATCH 8/8] Add stubtest to dev checks --- .github/workflows/ci.yml | 3 +++ AGENTS.md | 1 + 2 files changed, 4 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f9aad94..6a2b502 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -53,6 +53,9 @@ jobs: - name: Mypy type check run: uv run --dev mypy . + - name: Stubtest runtime typing check + run: uv run --dev stubtest alternative + - name: Run tests run: | # Some GitHub-hosted runners export PYTEST_DISABLE_PLUGIN_AUTOLOAD=1. diff --git a/AGENTS.md b/AGENTS.md index 36d49d3..b2fd73a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -7,6 +7,7 @@ The repository defines testing via GitHub actions. When contributing: * `uv run --dev pyrefly check .` * `uv run --dev pyright typing_tests` * `uv run --dev mypy .` + * `uv run --dev stubtest alternative` * `uv run --dev pytest --verbosity=2 --cov=alternative --cov-report=xml --cov-fail-under=100 --junit-xml=test-results.xml` * `uv run --group=docs sphinx-build --fail-on-warning --keep-going --builder=html docs /tmp/alternative-docs-html` * Format code with `uv run --dev ruff format .` before committing.