From 3d8c724c0432e74fe37025ab314a9d9291a9e349 Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Tue, 14 Apr 2026 21:16:17 +0200 Subject: [PATCH 1/6] feat: add AuthenticationContext ambient primitive Add celeste.authentication_context module exposing a frozen pydantic AuthenticationContext keyed by (modality, operation), an authentication_scope() context manager, and resolve_authentication() consumed by create_client(). Wire ambient resolution into create_client(): when no explicit auth= or api_key= was passed and the operation is known, fall back to the bound AuthenticationContext before the existing BYOA / env-credential branches. Explicit kwargs still win. Add MissingAuthenticationError (subclass of CredentialsError): raised by resolve_authentication() when a scope is bound but has no entry for the requested (modality, operation). Distinct from MissingCredentialsError so multi-tenant callers can distinguish scoped-but-uncovered from env-missing. Re-export AuthenticationContext, authentication_scope, and MissingAuthenticationError from the top-level celeste package. Backward compatible: module-level celeste.text.generate(...) API is unchanged; namespaces stay stateless singletons; existing explicit auth= callers see identical behavior. --- src/celeste/__init__.py | 14 ++ src/celeste/authentication_context.py | 78 +++++++ src/celeste/exceptions.py | 14 ++ .../unit_tests/test_authentication_context.py | 213 ++++++++++++++++++ 4 files changed, 319 insertions(+) create mode 100644 src/celeste/authentication_context.py create mode 100644 tests/unit_tests/test_authentication_context.py diff --git a/src/celeste/__init__.py b/src/celeste/__init__.py index d2e78ef..25f7d2e 100644 --- a/src/celeste/__init__.py +++ b/src/celeste/__init__.py @@ -7,6 +7,11 @@ from celeste import providers as _providers # noqa: F401 from celeste.auth import APIKey, Authentication, AuthHeader, NoAuth +from celeste.authentication_context import ( + AuthenticationContext, + authentication_scope, + resolve_authentication, +) from celeste.client import ModalityClient from celeste.core import ( Capability, @@ -19,6 +24,7 @@ from celeste.exceptions import ( ClientNotFoundError, Error, + MissingAuthenticationError, ModelNotFoundError, ) from celeste.io import Input, Output, Usage @@ -248,6 +254,11 @@ def create_client( raise ClientNotFoundError(modality=resolved_modality, provider=target) modality_client_class = _CLIENT_MAP[(resolved_modality, target)] + # Ambient fallback: only when neither auth nor api_key was passed and the + # operation is known. Explicit kwargs always win. + if auth is None and api_key is None and resolved_operation is not None: + auth = resolve_authentication(resolved_modality, resolved_operation) + # Auth resolution: BYOA for protocol path, credentials for provider path if resolved_protocol is not None and resolved_provider is None: if auth is not None: @@ -276,12 +287,14 @@ def create_client( __all__ = [ "APIKey", "Authentication", + "AuthenticationContext", "Capability", "CodeExecution", "Content", "Error", "Input", "Message", + "MissingAuthenticationError", "Modality", "Model", "Operation", @@ -297,6 +310,7 @@ def create_client( "WebSearch", "XSearch", "audio", + "authentication_scope", "create_client", "documents", "get_model", diff --git a/src/celeste/authentication_context.py b/src/celeste/authentication_context.py new file mode 100644 index 0000000..44d6551 --- /dev/null +++ b/src/celeste/authentication_context.py @@ -0,0 +1,78 @@ +"""Ambient authentication scope keyed by (modality, operation).""" + +from __future__ import annotations + +from collections.abc import Iterator, Mapping +from contextlib import contextmanager +from contextvars import ContextVar + +from pydantic import BaseModel, ConfigDict + +from celeste.auth import Authentication +from celeste.core import Modality, Operation +from celeste.exceptions import MissingAuthenticationError + + +class AuthenticationContext(BaseModel): + """Per-(modality, operation) authentication bound to an async scope.""" + + # Frozen: asyncio.gather siblings share the context snapshot by reference, + # so mutability would cause silent cross-task credential bleed. + model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) + + entries: Mapping[tuple[Modality, Operation], Authentication | None] + + def get_for( + self, modality: Modality, operation: Operation + ) -> Authentication | None: + """Return the authentication for a given (modality, operation).""" + return self.entries.get((modality, operation)) + + +_current_context: ContextVar[AuthenticationContext | None] = ContextVar( + "celeste.authentication_context.current", + default=None, +) + + +@contextmanager +def authentication_scope( + context: AuthenticationContext | None, +) -> Iterator[None]: + """Bind an authentication context for the current async scope. + + Within the ``with`` block, calls to ``celeste..(...)`` + that don't pass an explicit ``auth=`` or ``api_key=`` resolve their + authentication from the bound context. + """ + token = _current_context.set(context) + try: + yield + finally: + _current_context.reset(token) + + +def resolve_authentication( + modality: Modality, operation: Operation +) -> Authentication | None: + """Look up (modality, operation) in the current ambient context. + + Returns ``None`` when no scope is bound (caller may fall back to env). + Raises ``MissingAuthenticationError`` when a scope is bound but has no + authentication for the requested (modality, operation) — the caller + explicitly scoped auth and this slot is uncovered. + """ + context = _current_context.get() + if context is None: + return None + auth = context.get_for(modality, operation) + if auth is None: + raise MissingAuthenticationError(modality=modality, operation=operation) + return auth + + +__all__ = [ + "AuthenticationContext", + "authentication_scope", + "resolve_authentication", +] diff --git a/src/celeste/exceptions.py b/src/celeste/exceptions.py index 83f74ac..1f6446b 100644 --- a/src/celeste/exceptions.py +++ b/src/celeste/exceptions.py @@ -2,6 +2,8 @@ from typing import Any +from celeste.core import Modality, Operation + class Error(Exception): """Base exception for all Celeste errors.""" @@ -233,6 +235,17 @@ def __init__(self, provider: str) -> None: ) +class MissingAuthenticationError(CredentialsError): + """Raised when authentication cannot be resolved for a (modality, operation).""" + + def __init__(self, *, modality: Modality, operation: Operation) -> None: + self.modality = modality + self.operation = operation + super().__init__( + f"No authentication configured for {modality.value}/{operation.value}" + ) + + class InvalidToolError(ValidationError): """Raised when a tool item is not a Tool instance or dict.""" @@ -263,6 +276,7 @@ class UnsupportedParameterWarning(UserWarning): "ConstraintViolationError", "Error", "InvalidToolError", + "MissingAuthenticationError", "MissingCredentialsError", "MissingDependencyError", "ModalityNotFoundError", diff --git a/tests/unit_tests/test_authentication_context.py b/tests/unit_tests/test_authentication_context.py new file mode 100644 index 0000000..70b36bf --- /dev/null +++ b/tests/unit_tests/test_authentication_context.py @@ -0,0 +1,213 @@ +"""Tests for the ambient authentication context primitive.""" + +from __future__ import annotations + +import asyncio +import contextvars +from concurrent.futures import ThreadPoolExecutor + +import pytest +from pydantic import SecretStr, ValidationError + +from celeste.auth import AuthHeader, NoAuth +from celeste.authentication_context import ( + AuthenticationContext, + _current_context, + authentication_scope, + resolve_authentication, +) +from celeste.core import Modality, Operation +from celeste.exceptions import MissingAuthenticationError + + +@pytest.fixture +def text_auth() -> AuthHeader: + return AuthHeader(secret=SecretStr("text-key")) + + +@pytest.fixture +def images_auth() -> AuthHeader: + return AuthHeader(secret=SecretStr("images-key")) + + +@pytest.fixture +def audio_auth() -> NoAuth: + return NoAuth() + + +@pytest.fixture +def full_context( + text_auth: AuthHeader, + images_auth: AuthHeader, + audio_auth: NoAuth, +) -> AuthenticationContext: + return AuthenticationContext( + entries={ + (Modality.TEXT, Operation.GENERATE): text_auth, + (Modality.IMAGES, Operation.GENERATE): images_auth, + (Modality.AUDIO, Operation.SPEAK): audio_auth, + } + ) + + +class TestAuthenticationContextModel: + def test_frozen_rejects_mutation(self, full_context: AuthenticationContext) -> None: + with pytest.raises(ValidationError): + full_context.entries = {} # type: ignore[misc] + + def test_get_for_returns_bound_entry( + self, full_context: AuthenticationContext, text_auth: AuthHeader + ) -> None: + assert full_context.get_for(Modality.TEXT, Operation.GENERATE) is text_auth + + def test_get_for_missing_entry_returns_none( + self, full_context: AuthenticationContext + ) -> None: + assert full_context.get_for(Modality.VIDEOS, Operation.GENERATE) is None + + +class TestAuthenticationScope: + def test_outside_scope_resolves_none(self) -> None: + assert resolve_authentication(Modality.TEXT, Operation.GENERATE) is None + + def test_inside_scope_resolves_bound_auth( + self, full_context: AuthenticationContext, text_auth: AuthHeader + ) -> None: + with authentication_scope(full_context): + assert ( + resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth + ) + + def test_exit_restores_previous_state( + self, full_context: AuthenticationContext + ) -> None: + assert _current_context.get() is None + with authentication_scope(full_context): + assert _current_context.get() is full_context + assert _current_context.get() is None + + def test_nested_scopes_inner_wins_outer_restored( + self, + full_context: AuthenticationContext, + text_auth: AuthHeader, + ) -> None: + inner_auth = AuthHeader(secret=SecretStr("inner-key")) + inner_context = AuthenticationContext( + entries={(Modality.TEXT, Operation.GENERATE): inner_auth} + ) + + with authentication_scope(full_context): + assert ( + resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth + ) + with authentication_scope(inner_context): + assert ( + resolve_authentication(Modality.TEXT, Operation.GENERATE) + is inner_auth + ) + assert ( + resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth + ) + + def test_scope_with_none_clears_binding( + self, full_context: AuthenticationContext + ) -> None: + with authentication_scope(full_context): + with authentication_scope(None): + assert resolve_authentication(Modality.TEXT, Operation.GENERATE) is None + assert _current_context.get() is full_context + + +class TestResolveAuthentication: + def test_scope_bound_but_entry_missing_raises( + self, full_context: AuthenticationContext + ) -> None: + with ( + authentication_scope(full_context), + pytest.raises(MissingAuthenticationError) as exc_info, + ): + resolve_authentication(Modality.VIDEOS, Operation.GENERATE) + err = exc_info.value + assert err.modality is Modality.VIDEOS + assert err.operation is Operation.GENERATE + assert "videos/generate" in str(err) + + def test_scope_bound_with_explicit_none_raises(self) -> None: + context = AuthenticationContext( + entries={(Modality.TEXT, Operation.GENERATE): None} + ) + with ( + authentication_scope(context), + pytest.raises(MissingAuthenticationError), + ): + resolve_authentication(Modality.TEXT, Operation.GENERATE) + + +class TestAsyncPropagation: + @pytest.mark.asyncio + async def test_create_task_propagates( + self, full_context: AuthenticationContext, images_auth: AuthHeader + ) -> None: + async def child() -> object: + return resolve_authentication(Modality.IMAGES, Operation.GENERATE) + + with authentication_scope(full_context): + task = asyncio.create_task(child()) + result = await task + + assert result is images_auth + + @pytest.mark.asyncio + async def test_gather_siblings_share_snapshot( + self, full_context: AuthenticationContext, images_auth: AuthHeader + ) -> None: + async def child() -> object: + return resolve_authentication(Modality.IMAGES, Operation.GENERATE) + + with authentication_scope(full_context): + results = await asyncio.gather(child(), child(), child()) + + assert all(r is images_auth for r in results) + + @pytest.mark.asyncio + async def test_to_thread_propagates( + self, full_context: AuthenticationContext, text_auth: AuthHeader + ) -> None: + def sync_worker() -> object: + return resolve_authentication(Modality.TEXT, Operation.GENERATE) + + with authentication_scope(full_context): + result = await asyncio.to_thread(sync_worker) + + assert result is text_auth + + def test_raw_thread_pool_does_not_propagate( + self, full_context: AuthenticationContext + ) -> None: + def sync_worker() -> object: + return resolve_authentication(Modality.TEXT, Operation.GENERATE) + + with ( + ThreadPoolExecutor(max_workers=1) as pool, + authentication_scope(full_context), + ): + future = pool.submit(sync_worker) + result = future.result() + + assert result is None + + def test_raw_thread_pool_with_copy_context_propagates( + self, full_context: AuthenticationContext, text_auth: AuthHeader + ) -> None: + def sync_worker() -> object: + return resolve_authentication(Modality.TEXT, Operation.GENERATE) + + with ( + ThreadPoolExecutor(max_workers=1) as pool, + authentication_scope(full_context), + ): + ctx_snapshot = contextvars.copy_context() + future = pool.submit(ctx_snapshot.run, sync_worker) + result = future.result() + + assert result is text_auth From cb7476233b241ca6c2af8930481474729d1df8cc Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Tue, 14 Apr 2026 21:35:25 +0200 Subject: [PATCH 2/6] style: drop unnecessary `from __future__ import annotations` celeste-python targets Python 3.12+, where PEP 604 union syntax (`X | Y`) and PEP 585 generic collections (`list[int]`, `tuple[...]`) evaluate natively. The future import is not used anywhere else in celeste-python (0/332 files) and introduces friction with pydantic v2 forward-reference resolution. Remove from the new authentication_context module and its test file. --- src/celeste/authentication_context.py | 2 -- tests/unit_tests/test_authentication_context.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/src/celeste/authentication_context.py b/src/celeste/authentication_context.py index 44d6551..322246b 100644 --- a/src/celeste/authentication_context.py +++ b/src/celeste/authentication_context.py @@ -1,7 +1,5 @@ """Ambient authentication scope keyed by (modality, operation).""" -from __future__ import annotations - from collections.abc import Iterator, Mapping from contextlib import contextmanager from contextvars import ContextVar diff --git a/tests/unit_tests/test_authentication_context.py b/tests/unit_tests/test_authentication_context.py index 70b36bf..71ffd5e 100644 --- a/tests/unit_tests/test_authentication_context.py +++ b/tests/unit_tests/test_authentication_context.py @@ -1,7 +1,5 @@ """Tests for the ambient authentication context primitive.""" -from __future__ import annotations - import asyncio import contextvars from concurrent.futures import ThreadPoolExecutor From 5cc1cafd72d0c37547007257acb9949113b9913f Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Tue, 14 Apr 2026 21:57:15 +0200 Subject: [PATCH 3/6] refactor: drop AuthenticationContext.get_for wrapper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Eliminate the one-line wrapper method in favor of direct dict access on the frozen entries mapping. resolve_authentication now calls context.entries.get((modality, operation)) directly, which removes the public method and sidesteps the naming-convention question about get_for vs the celeste verb_noun house style. Rewrite two tests that were reaching into _current_context.get() to verify scope state through resolve_authentication() instead — same invariants, no coupling to module internals. Drop the two trivial get_for tests that were only verifying dict.get semantics; the resolution tests already cover the same ground end-to-end. Trim the resolve_authentication docstring to a one-line summary plus a Raises block, matching the celeste docstring style. --- src/celeste/authentication_context.py | 15 +++------- .../unit_tests/test_authentication_context.py | 30 ++++++++----------- 2 files changed, 17 insertions(+), 28 deletions(-) diff --git a/src/celeste/authentication_context.py b/src/celeste/authentication_context.py index 322246b..f549b8a 100644 --- a/src/celeste/authentication_context.py +++ b/src/celeste/authentication_context.py @@ -20,12 +20,6 @@ class AuthenticationContext(BaseModel): entries: Mapping[tuple[Modality, Operation], Authentication | None] - def get_for( - self, modality: Modality, operation: Operation - ) -> Authentication | None: - """Return the authentication for a given (modality, operation).""" - return self.entries.get((modality, operation)) - _current_context: ContextVar[AuthenticationContext | None] = ContextVar( "celeste.authentication_context.current", @@ -55,15 +49,14 @@ def resolve_authentication( ) -> Authentication | None: """Look up (modality, operation) in the current ambient context. - Returns ``None`` when no scope is bound (caller may fall back to env). - Raises ``MissingAuthenticationError`` when a scope is bound but has no - authentication for the requested (modality, operation) — the caller - explicitly scoped auth and this slot is uncovered. + Raises: + MissingAuthenticationError: A scope is bound but has no authentication + for the requested (modality, operation). """ context = _current_context.get() if context is None: return None - auth = context.get_for(modality, operation) + auth = context.entries.get((modality, operation)) if auth is None: raise MissingAuthenticationError(modality=modality, operation=operation) return auth diff --git a/tests/unit_tests/test_authentication_context.py b/tests/unit_tests/test_authentication_context.py index 71ffd5e..73587b2 100644 --- a/tests/unit_tests/test_authentication_context.py +++ b/tests/unit_tests/test_authentication_context.py @@ -10,7 +10,6 @@ from celeste.auth import AuthHeader, NoAuth from celeste.authentication_context import ( AuthenticationContext, - _current_context, authentication_scope, resolve_authentication, ) @@ -53,16 +52,6 @@ def test_frozen_rejects_mutation(self, full_context: AuthenticationContext) -> N with pytest.raises(ValidationError): full_context.entries = {} # type: ignore[misc] - def test_get_for_returns_bound_entry( - self, full_context: AuthenticationContext, text_auth: AuthHeader - ) -> None: - assert full_context.get_for(Modality.TEXT, Operation.GENERATE) is text_auth - - def test_get_for_missing_entry_returns_none( - self, full_context: AuthenticationContext - ) -> None: - assert full_context.get_for(Modality.VIDEOS, Operation.GENERATE) is None - class TestAuthenticationScope: def test_outside_scope_resolves_none(self) -> None: @@ -77,12 +66,14 @@ def test_inside_scope_resolves_bound_auth( ) def test_exit_restores_previous_state( - self, full_context: AuthenticationContext + self, full_context: AuthenticationContext, text_auth: AuthHeader ) -> None: - assert _current_context.get() is None + assert resolve_authentication(Modality.TEXT, Operation.GENERATE) is None with authentication_scope(full_context): - assert _current_context.get() is full_context - assert _current_context.get() is None + assert ( + resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth + ) + assert resolve_authentication(Modality.TEXT, Operation.GENERATE) is None def test_nested_scopes_inner_wins_outer_restored( self, @@ -108,12 +99,17 @@ def test_nested_scopes_inner_wins_outer_restored( ) def test_scope_with_none_clears_binding( - self, full_context: AuthenticationContext + self, full_context: AuthenticationContext, text_auth: AuthHeader ) -> None: with authentication_scope(full_context): + assert ( + resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth + ) with authentication_scope(None): assert resolve_authentication(Modality.TEXT, Operation.GENERATE) is None - assert _current_context.get() is full_context + assert ( + resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth + ) class TestResolveAuthentication: From dcf8236ada709f56c839ae0186f819e3e2ac341b Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Tue, 14 Apr 2026 22:08:09 +0200 Subject: [PATCH 4/6] refactor: merge authentication_context module into auth.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Consolidate the ambient-authentication primitive into auth.py where it lives naturally alongside Authentication, AuthHeader, NoAuth, and the existing auth registry. Delete the standalone authentication_context.py (was 70 lines) — there is no meaningful "one concern per module" split between the authentication types and the ambient scope machinery that references them. Merge the corresponding test_authentication_context.py tests into test_auth.py so test layout mirrors source layout. The existing clean_auth_registry autouse fixture is harmless for the new tests. celeste.__init__ and __all__ now import AuthenticationContext, authentication_scope, and resolve_authentication directly from celeste.auth. Public API surface is unchanged — top-level imports still work via `from celeste import authentication_scope, AuthenticationContext`. --- src/celeste/__init__.py | 7 +- src/celeste/auth.py | 62 ++++- src/celeste/authentication_context.py | 69 ------ tests/unit_tests/test_auth.py | 211 +++++++++++++++++- .../unit_tests/test_authentication_context.py | 207 ----------------- 5 files changed, 275 insertions(+), 281 deletions(-) delete mode 100644 src/celeste/authentication_context.py delete mode 100644 tests/unit_tests/test_authentication_context.py diff --git a/src/celeste/__init__.py b/src/celeste/__init__.py index 25f7d2e..3c61e07 100644 --- a/src/celeste/__init__.py +++ b/src/celeste/__init__.py @@ -6,9 +6,12 @@ from pydantic import SecretStr from celeste import providers as _providers # noqa: F401 -from celeste.auth import APIKey, Authentication, AuthHeader, NoAuth -from celeste.authentication_context import ( +from celeste.auth import ( + APIKey, + Authentication, AuthenticationContext, + AuthHeader, + NoAuth, authentication_scope, resolve_authentication, ) diff --git a/src/celeste/auth.py b/src/celeste/auth.py index ce127bc..7b2cd08 100644 --- a/src/celeste/auth.py +++ b/src/celeste/auth.py @@ -1,8 +1,14 @@ """Authentication methods for Celeste providers.""" from abc import ABC, abstractmethod +from collections.abc import Iterator, Mapping +from contextlib import contextmanager +from contextvars import ContextVar -from pydantic import BaseModel, SecretStr, field_validator +from pydantic import BaseModel, ConfigDict, SecretStr, field_validator + +from celeste.core import Modality, Operation +from celeste.exceptions import MissingAuthenticationError # Module-level registry (same pattern as _clients and _models) _auth_classes: dict[str, type["Authentication"]] = {} @@ -86,11 +92,65 @@ def get_auth_class(auth_type: str) -> type[Authentication]: return _auth_classes[auth_type] +class AuthenticationContext(BaseModel): + """Per-(modality, operation) authentication bound to an async scope.""" + + # Frozen: asyncio.gather siblings share the context snapshot by reference, + # so mutability would cause silent cross-task credential bleed. + model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) + + entries: Mapping[tuple[Modality, Operation], Authentication | None] + + +_current_context: ContextVar[AuthenticationContext | None] = ContextVar( + "celeste.auth.current_context", + default=None, +) + + +@contextmanager +def authentication_scope( + context: AuthenticationContext | None, +) -> Iterator[None]: + """Bind an authentication context for the current async scope. + + Within the ``with`` block, calls to ``celeste..(...)`` + that don't pass an explicit ``auth=`` or ``api_key=`` resolve their + authentication from the bound context. + """ + token = _current_context.set(context) + try: + yield + finally: + _current_context.reset(token) + + +def resolve_authentication( + modality: Modality, operation: Operation +) -> Authentication | None: + """Look up (modality, operation) in the current ambient context. + + Raises: + MissingAuthenticationError: A scope is bound but has no authentication + for the requested (modality, operation). + """ + context = _current_context.get() + if context is None: + return None + auth = context.entries.get((modality, operation)) + if auth is None: + raise MissingAuthenticationError(modality=modality, operation=operation) + return auth + + __all__ = [ "APIKey", "AuthHeader", "Authentication", + "AuthenticationContext", "NoAuth", + "authentication_scope", "get_auth_class", "register_auth", + "resolve_authentication", ] diff --git a/src/celeste/authentication_context.py b/src/celeste/authentication_context.py deleted file mode 100644 index f549b8a..0000000 --- a/src/celeste/authentication_context.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Ambient authentication scope keyed by (modality, operation).""" - -from collections.abc import Iterator, Mapping -from contextlib import contextmanager -from contextvars import ContextVar - -from pydantic import BaseModel, ConfigDict - -from celeste.auth import Authentication -from celeste.core import Modality, Operation -from celeste.exceptions import MissingAuthenticationError - - -class AuthenticationContext(BaseModel): - """Per-(modality, operation) authentication bound to an async scope.""" - - # Frozen: asyncio.gather siblings share the context snapshot by reference, - # so mutability would cause silent cross-task credential bleed. - model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) - - entries: Mapping[tuple[Modality, Operation], Authentication | None] - - -_current_context: ContextVar[AuthenticationContext | None] = ContextVar( - "celeste.authentication_context.current", - default=None, -) - - -@contextmanager -def authentication_scope( - context: AuthenticationContext | None, -) -> Iterator[None]: - """Bind an authentication context for the current async scope. - - Within the ``with`` block, calls to ``celeste..(...)`` - that don't pass an explicit ``auth=`` or ``api_key=`` resolve their - authentication from the bound context. - """ - token = _current_context.set(context) - try: - yield - finally: - _current_context.reset(token) - - -def resolve_authentication( - modality: Modality, operation: Operation -) -> Authentication | None: - """Look up (modality, operation) in the current ambient context. - - Raises: - MissingAuthenticationError: A scope is bound but has no authentication - for the requested (modality, operation). - """ - context = _current_context.get() - if context is None: - return None - auth = context.entries.get((modality, operation)) - if auth is None: - raise MissingAuthenticationError(modality=modality, operation=operation) - return auth - - -__all__ = [ - "AuthenticationContext", - "authentication_scope", - "resolve_authentication", -] diff --git a/tests/unit_tests/test_auth.py b/tests/unit_tests/test_auth.py index 3212e72..bde9108 100644 --- a/tests/unit_tests/test_auth.py +++ b/tests/unit_tests/test_auth.py @@ -1,17 +1,26 @@ -"""Tests for authentication classes and registry.""" +"""Tests for authentication classes, registry, and ambient scope.""" +import asyncio +import contextvars from collections.abc import Generator +from concurrent.futures import ThreadPoolExecutor import pytest -from pydantic import SecretStr +from pydantic import SecretStr, ValidationError from celeste.auth import ( APIKey, Authentication, + AuthenticationContext, AuthHeader, + NoAuth, + authentication_scope, get_auth_class, register_auth, + resolve_authentication, ) +from celeste.core import Modality, Operation +from celeste.exceptions import MissingAuthenticationError @pytest.fixture(autouse=True) @@ -100,3 +109,201 @@ def test_get_auth_class_with_unknown_type_raises(self) -> None: ValueError, match=r"Unknown auth type: nonexistent.*Available:" ): get_auth_class("nonexistent") + + +@pytest.fixture +def text_auth() -> AuthHeader: + return AuthHeader(secret=SecretStr("text-key")) + + +@pytest.fixture +def images_auth() -> AuthHeader: + return AuthHeader(secret=SecretStr("images-key")) + + +@pytest.fixture +def audio_auth() -> NoAuth: + return NoAuth() + + +@pytest.fixture +def full_context( + text_auth: AuthHeader, + images_auth: AuthHeader, + audio_auth: NoAuth, +) -> AuthenticationContext: + return AuthenticationContext( + entries={ + (Modality.TEXT, Operation.GENERATE): text_auth, + (Modality.IMAGES, Operation.GENERATE): images_auth, + (Modality.AUDIO, Operation.SPEAK): audio_auth, + } + ) + + +class TestAuthenticationContext: + """Test AuthenticationContext pydantic model.""" + + def test_frozen_rejects_mutation(self, full_context: AuthenticationContext) -> None: + with pytest.raises(ValidationError): + full_context.entries = {} # type: ignore[misc] + + +class TestAuthenticationScope: + """Test authentication_scope context manager.""" + + def test_outside_scope_resolves_none(self) -> None: + assert resolve_authentication(Modality.TEXT, Operation.GENERATE) is None + + def test_inside_scope_resolves_bound_auth( + self, full_context: AuthenticationContext, text_auth: AuthHeader + ) -> None: + with authentication_scope(full_context): + assert ( + resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth + ) + + def test_exit_restores_previous_state( + self, full_context: AuthenticationContext, text_auth: AuthHeader + ) -> None: + assert resolve_authentication(Modality.TEXT, Operation.GENERATE) is None + with authentication_scope(full_context): + assert ( + resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth + ) + assert resolve_authentication(Modality.TEXT, Operation.GENERATE) is None + + def test_nested_scopes_inner_wins_outer_restored( + self, + full_context: AuthenticationContext, + text_auth: AuthHeader, + ) -> None: + inner_auth = AuthHeader(secret=SecretStr("inner-key")) + inner_context = AuthenticationContext( + entries={(Modality.TEXT, Operation.GENERATE): inner_auth} + ) + + with authentication_scope(full_context): + assert ( + resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth + ) + with authentication_scope(inner_context): + assert ( + resolve_authentication(Modality.TEXT, Operation.GENERATE) + is inner_auth + ) + assert ( + resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth + ) + + def test_scope_with_none_clears_binding( + self, full_context: AuthenticationContext, text_auth: AuthHeader + ) -> None: + with authentication_scope(full_context): + assert ( + resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth + ) + with authentication_scope(None): + assert resolve_authentication(Modality.TEXT, Operation.GENERATE) is None + assert ( + resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth + ) + + +class TestResolveAuthentication: + """Test resolve_authentication lookup helper.""" + + def test_scope_bound_but_entry_missing_raises( + self, full_context: AuthenticationContext + ) -> None: + with ( + authentication_scope(full_context), + pytest.raises(MissingAuthenticationError) as exc_info, + ): + resolve_authentication(Modality.VIDEOS, Operation.GENERATE) + err = exc_info.value + assert err.modality is Modality.VIDEOS + assert err.operation is Operation.GENERATE + assert "videos/generate" in str(err) + + def test_scope_bound_with_explicit_none_raises(self) -> None: + context = AuthenticationContext( + entries={(Modality.TEXT, Operation.GENERATE): None} + ) + with ( + authentication_scope(context), + pytest.raises(MissingAuthenticationError), + ): + resolve_authentication(Modality.TEXT, Operation.GENERATE) + + +class TestAsyncPropagation: + """Test ContextVar propagation across async boundaries.""" + + @pytest.mark.asyncio + async def test_create_task_propagates( + self, full_context: AuthenticationContext, images_auth: AuthHeader + ) -> None: + async def child() -> object: + return resolve_authentication(Modality.IMAGES, Operation.GENERATE) + + with authentication_scope(full_context): + task = asyncio.create_task(child()) + result = await task + + assert result is images_auth + + @pytest.mark.asyncio + async def test_gather_siblings_share_snapshot( + self, full_context: AuthenticationContext, images_auth: AuthHeader + ) -> None: + async def child() -> object: + return resolve_authentication(Modality.IMAGES, Operation.GENERATE) + + with authentication_scope(full_context): + results = await asyncio.gather(child(), child(), child()) + + assert all(r is images_auth for r in results) + + @pytest.mark.asyncio + async def test_to_thread_propagates( + self, full_context: AuthenticationContext, text_auth: AuthHeader + ) -> None: + def sync_worker() -> object: + return resolve_authentication(Modality.TEXT, Operation.GENERATE) + + with authentication_scope(full_context): + result = await asyncio.to_thread(sync_worker) + + assert result is text_auth + + def test_raw_thread_pool_does_not_propagate( + self, full_context: AuthenticationContext + ) -> None: + def sync_worker() -> object: + return resolve_authentication(Modality.TEXT, Operation.GENERATE) + + with ( + ThreadPoolExecutor(max_workers=1) as pool, + authentication_scope(full_context), + ): + future = pool.submit(sync_worker) + result = future.result() + + assert result is None + + def test_raw_thread_pool_with_copy_context_propagates( + self, full_context: AuthenticationContext, text_auth: AuthHeader + ) -> None: + def sync_worker() -> object: + return resolve_authentication(Modality.TEXT, Operation.GENERATE) + + with ( + ThreadPoolExecutor(max_workers=1) as pool, + authentication_scope(full_context), + ): + ctx_snapshot = contextvars.copy_context() + future = pool.submit(ctx_snapshot.run, sync_worker) + result = future.result() + + assert result is text_auth diff --git a/tests/unit_tests/test_authentication_context.py b/tests/unit_tests/test_authentication_context.py deleted file mode 100644 index 73587b2..0000000 --- a/tests/unit_tests/test_authentication_context.py +++ /dev/null @@ -1,207 +0,0 @@ -"""Tests for the ambient authentication context primitive.""" - -import asyncio -import contextvars -from concurrent.futures import ThreadPoolExecutor - -import pytest -from pydantic import SecretStr, ValidationError - -from celeste.auth import AuthHeader, NoAuth -from celeste.authentication_context import ( - AuthenticationContext, - authentication_scope, - resolve_authentication, -) -from celeste.core import Modality, Operation -from celeste.exceptions import MissingAuthenticationError - - -@pytest.fixture -def text_auth() -> AuthHeader: - return AuthHeader(secret=SecretStr("text-key")) - - -@pytest.fixture -def images_auth() -> AuthHeader: - return AuthHeader(secret=SecretStr("images-key")) - - -@pytest.fixture -def audio_auth() -> NoAuth: - return NoAuth() - - -@pytest.fixture -def full_context( - text_auth: AuthHeader, - images_auth: AuthHeader, - audio_auth: NoAuth, -) -> AuthenticationContext: - return AuthenticationContext( - entries={ - (Modality.TEXT, Operation.GENERATE): text_auth, - (Modality.IMAGES, Operation.GENERATE): images_auth, - (Modality.AUDIO, Operation.SPEAK): audio_auth, - } - ) - - -class TestAuthenticationContextModel: - def test_frozen_rejects_mutation(self, full_context: AuthenticationContext) -> None: - with pytest.raises(ValidationError): - full_context.entries = {} # type: ignore[misc] - - -class TestAuthenticationScope: - def test_outside_scope_resolves_none(self) -> None: - assert resolve_authentication(Modality.TEXT, Operation.GENERATE) is None - - def test_inside_scope_resolves_bound_auth( - self, full_context: AuthenticationContext, text_auth: AuthHeader - ) -> None: - with authentication_scope(full_context): - assert ( - resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth - ) - - def test_exit_restores_previous_state( - self, full_context: AuthenticationContext, text_auth: AuthHeader - ) -> None: - assert resolve_authentication(Modality.TEXT, Operation.GENERATE) is None - with authentication_scope(full_context): - assert ( - resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth - ) - assert resolve_authentication(Modality.TEXT, Operation.GENERATE) is None - - def test_nested_scopes_inner_wins_outer_restored( - self, - full_context: AuthenticationContext, - text_auth: AuthHeader, - ) -> None: - inner_auth = AuthHeader(secret=SecretStr("inner-key")) - inner_context = AuthenticationContext( - entries={(Modality.TEXT, Operation.GENERATE): inner_auth} - ) - - with authentication_scope(full_context): - assert ( - resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth - ) - with authentication_scope(inner_context): - assert ( - resolve_authentication(Modality.TEXT, Operation.GENERATE) - is inner_auth - ) - assert ( - resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth - ) - - def test_scope_with_none_clears_binding( - self, full_context: AuthenticationContext, text_auth: AuthHeader - ) -> None: - with authentication_scope(full_context): - assert ( - resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth - ) - with authentication_scope(None): - assert resolve_authentication(Modality.TEXT, Operation.GENERATE) is None - assert ( - resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth - ) - - -class TestResolveAuthentication: - def test_scope_bound_but_entry_missing_raises( - self, full_context: AuthenticationContext - ) -> None: - with ( - authentication_scope(full_context), - pytest.raises(MissingAuthenticationError) as exc_info, - ): - resolve_authentication(Modality.VIDEOS, Operation.GENERATE) - err = exc_info.value - assert err.modality is Modality.VIDEOS - assert err.operation is Operation.GENERATE - assert "videos/generate" in str(err) - - def test_scope_bound_with_explicit_none_raises(self) -> None: - context = AuthenticationContext( - entries={(Modality.TEXT, Operation.GENERATE): None} - ) - with ( - authentication_scope(context), - pytest.raises(MissingAuthenticationError), - ): - resolve_authentication(Modality.TEXT, Operation.GENERATE) - - -class TestAsyncPropagation: - @pytest.mark.asyncio - async def test_create_task_propagates( - self, full_context: AuthenticationContext, images_auth: AuthHeader - ) -> None: - async def child() -> object: - return resolve_authentication(Modality.IMAGES, Operation.GENERATE) - - with authentication_scope(full_context): - task = asyncio.create_task(child()) - result = await task - - assert result is images_auth - - @pytest.mark.asyncio - async def test_gather_siblings_share_snapshot( - self, full_context: AuthenticationContext, images_auth: AuthHeader - ) -> None: - async def child() -> object: - return resolve_authentication(Modality.IMAGES, Operation.GENERATE) - - with authentication_scope(full_context): - results = await asyncio.gather(child(), child(), child()) - - assert all(r is images_auth for r in results) - - @pytest.mark.asyncio - async def test_to_thread_propagates( - self, full_context: AuthenticationContext, text_auth: AuthHeader - ) -> None: - def sync_worker() -> object: - return resolve_authentication(Modality.TEXT, Operation.GENERATE) - - with authentication_scope(full_context): - result = await asyncio.to_thread(sync_worker) - - assert result is text_auth - - def test_raw_thread_pool_does_not_propagate( - self, full_context: AuthenticationContext - ) -> None: - def sync_worker() -> object: - return resolve_authentication(Modality.TEXT, Operation.GENERATE) - - with ( - ThreadPoolExecutor(max_workers=1) as pool, - authentication_scope(full_context), - ): - future = pool.submit(sync_worker) - result = future.result() - - assert result is None - - def test_raw_thread_pool_with_copy_context_propagates( - self, full_context: AuthenticationContext, text_auth: AuthHeader - ) -> None: - def sync_worker() -> object: - return resolve_authentication(Modality.TEXT, Operation.GENERATE) - - with ( - ThreadPoolExecutor(max_workers=1) as pool, - authentication_scope(full_context), - ): - ctx_snapshot = contextvars.copy_context() - future = pool.submit(ctx_snapshot.run, sync_worker) - result = future.result() - - assert result is text_auth From c15ef29de63e41dcc457878558ca0db9c1b08018 Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Wed, 15 Apr 2026 10:05:43 +0200 Subject: [PATCH 5/6] docs: complete docstrings for ambient auth API Add Args/Returns blocks to resolve_authentication and authentication_scope matching the celeste convention (full Args/Returns/Raises blocks for functions with parameters, see Credentials.get_auth and get_auth_class). Update create_client docstring to document the new ambient resolution fallback path on the auth parameter, and add MissingAuthenticationError to the Raises block. --- src/celeste/__init__.py | 5 +++++ src/celeste/auth.py | 11 +++++++++++ 2 files changed, 16 insertions(+) diff --git a/src/celeste/__init__.py b/src/celeste/__init__.py index 3c61e07..9c91db6 100644 --- a/src/celeste/__init__.py +++ b/src/celeste/__init__.py @@ -196,6 +196,9 @@ def create_client( model: Model object, string model ID, or None for auto-selection. api_key: Optional API key override (string or SecretStr). auth: Optional Authentication object for custom auth (e.g., GoogleADC). + When None and api_key is also None, falls back to the ambient + AuthenticationContext bound by ``authentication_scope(...)`` for + the (modality, operation) pair, before the env-credential path. protocol: Wire format protocol for compatible APIs (e.g., "openresponses", "chatcompletions"). Use with base_url for third-party compatible APIs. base_url: Custom base URL override. Use with protocol for compatible APIs, @@ -208,6 +211,8 @@ def create_client( ModelNotFoundError: If no model found for the specified capability/provider. ClientNotFoundError: If no client registered for capability/provider/protocol. MissingCredentialsError: If required credentials are not configured. + MissingAuthenticationError: If an ambient AuthenticationContext is bound + but has no entry for the requested (modality, operation). ValueError: If capability/operation cannot be inferred from model. """ # Translation layer: convert deprecated capability to modality/operation diff --git a/src/celeste/auth.py b/src/celeste/auth.py index 7b2cd08..2ad84ff 100644 --- a/src/celeste/auth.py +++ b/src/celeste/auth.py @@ -117,6 +117,10 @@ def authentication_scope( Within the ``with`` block, calls to ``celeste..(...)`` that don't pass an explicit ``auth=`` or ``api_key=`` resolve their authentication from the bound context. + + Args: + context: The AuthenticationContext to bind, or None to clear any + outer binding for the duration of the block. """ token = _current_context.set(context) try: @@ -130,6 +134,13 @@ def resolve_authentication( ) -> Authentication | None: """Look up (modality, operation) in the current ambient context. + Args: + modality: The modality to look up. + operation: The operation to look up. + + Returns: + The bound Authentication, or None if no ambient context is active. + Raises: MissingAuthenticationError: A scope is bound but has no authentication for the requested (modality, operation). From 5dd971f29ad700ef9a20114e2afc585a28e2a00b Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Wed, 15 Apr 2026 10:36:12 +0200 Subject: [PATCH 6/6] refactor: rekey AuthenticationContext by Provider MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Switch the ambient context from per-(Modality, Operation) keying to per-Provider keying. Provider keying matches celeste's existing internal auth model (Credentials.get_auth(provider, ...) and the _auth_registry in credentials.py are both per-Provider), maps directly to user mental models ("my OpenAI key, my Anthropic key"), and eliminates redundancy when one provider serves multiple modalities. create_client() now consults the ambient context with resolved_provider (known immediately after _resolve_model). The lookup happens in the same location as before, just with a different key. MissingAuthenticationError now takes a single positional Provider, mirroring MissingCredentialsError's positional-provider style. AuthenticationContext.entries is now Mapping[Provider, Authentication] (no None values — missing keys signal "no auth," consistent with celeste's internal registry behavior). Tests rewritten with provider-named fixtures (openai_auth, anthropic_auth, elevenlabs_auth) keyed by Provider.OPENAI / .ANTHROPIC / .ELEVENLABS. The redundant explicit-None test is dropped since the value type no longer admits None — the missing-key case covers it. Same-provider edge cases (e.g., one user with both a Gemini API key and a Vertex OAuth token) are handled by the existing explicit auth= kwarg escape hatch on each celeste call. --- src/celeste/__init__.py | 10 +-- src/celeste/auth.py | 23 +++---- src/celeste/exceptions.py | 13 ++-- tests/unit_tests/test_auth.py | 118 +++++++++++++--------------------- 4 files changed, 65 insertions(+), 99 deletions(-) diff --git a/src/celeste/__init__.py b/src/celeste/__init__.py index 9c91db6..5663e1c 100644 --- a/src/celeste/__init__.py +++ b/src/celeste/__init__.py @@ -198,7 +198,7 @@ def create_client( auth: Optional Authentication object for custom auth (e.g., GoogleADC). When None and api_key is also None, falls back to the ambient AuthenticationContext bound by ``authentication_scope(...)`` for - the (modality, operation) pair, before the env-credential path. + the resolved model's provider, before the env-credential path. protocol: Wire format protocol for compatible APIs (e.g., "openresponses", "chatcompletions"). Use with base_url for third-party compatible APIs. base_url: Custom base URL override. Use with protocol for compatible APIs, @@ -212,7 +212,7 @@ def create_client( ClientNotFoundError: If no client registered for capability/provider/protocol. MissingCredentialsError: If required credentials are not configured. MissingAuthenticationError: If an ambient AuthenticationContext is bound - but has no entry for the requested (modality, operation). + but has no entry for the resolved model's provider. ValueError: If capability/operation cannot be inferred from model. """ # Translation layer: convert deprecated capability to modality/operation @@ -263,9 +263,9 @@ def create_client( modality_client_class = _CLIENT_MAP[(resolved_modality, target)] # Ambient fallback: only when neither auth nor api_key was passed and the - # operation is known. Explicit kwargs always win. - if auth is None and api_key is None and resolved_operation is not None: - auth = resolve_authentication(resolved_modality, resolved_operation) + # provider is known. Explicit kwargs always win. + if auth is None and api_key is None and resolved_provider is not None: + auth = resolve_authentication(resolved_provider) # Auth resolution: BYOA for protocol path, credentials for provider path if resolved_protocol is not None and resolved_provider is None: diff --git a/src/celeste/auth.py b/src/celeste/auth.py index 2ad84ff..18f8208 100644 --- a/src/celeste/auth.py +++ b/src/celeste/auth.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, SecretStr, field_validator -from celeste.core import Modality, Operation +from celeste.core import Provider from celeste.exceptions import MissingAuthenticationError # Module-level registry (same pattern as _clients and _models) @@ -93,13 +93,13 @@ def get_auth_class(auth_type: str) -> type[Authentication]: class AuthenticationContext(BaseModel): - """Per-(modality, operation) authentication bound to an async scope.""" + """Per-provider authentication bound to an async scope.""" # Frozen: asyncio.gather siblings share the context snapshot by reference, # so mutability would cause silent cross-task credential bleed. model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) - entries: Mapping[tuple[Modality, Operation], Authentication | None] + entries: Mapping[Provider, Authentication] _current_context: ContextVar[AuthenticationContext | None] = ContextVar( @@ -116,7 +116,7 @@ def authentication_scope( Within the ``with`` block, calls to ``celeste..(...)`` that don't pass an explicit ``auth=`` or ``api_key=`` resolve their - authentication from the bound context. + authentication from the bound context using the resolved model's provider. Args: context: The AuthenticationContext to bind, or None to clear any @@ -129,28 +129,25 @@ def authentication_scope( _current_context.reset(token) -def resolve_authentication( - modality: Modality, operation: Operation -) -> Authentication | None: - """Look up (modality, operation) in the current ambient context. +def resolve_authentication(provider: Provider) -> Authentication | None: + """Look up the provider in the current ambient context. Args: - modality: The modality to look up. - operation: The operation to look up. + provider: The provider to look up. Returns: The bound Authentication, or None if no ambient context is active. Raises: MissingAuthenticationError: A scope is bound but has no authentication - for the requested (modality, operation). + for the requested provider. """ context = _current_context.get() if context is None: return None - auth = context.entries.get((modality, operation)) + auth = context.entries.get(provider) if auth is None: - raise MissingAuthenticationError(modality=modality, operation=operation) + raise MissingAuthenticationError(provider) return auth diff --git a/src/celeste/exceptions.py b/src/celeste/exceptions.py index 1f6446b..9b11c37 100644 --- a/src/celeste/exceptions.py +++ b/src/celeste/exceptions.py @@ -2,7 +2,7 @@ from typing import Any -from celeste.core import Modality, Operation +from celeste.core import Provider class Error(Exception): @@ -236,14 +236,11 @@ def __init__(self, provider: str) -> None: class MissingAuthenticationError(CredentialsError): - """Raised when authentication cannot be resolved for a (modality, operation).""" + """Raised when authentication cannot be resolved for a provider.""" - def __init__(self, *, modality: Modality, operation: Operation) -> None: - self.modality = modality - self.operation = operation - super().__init__( - f"No authentication configured for {modality.value}/{operation.value}" - ) + def __init__(self, provider: Provider) -> None: + self.provider = provider + super().__init__(f"No authentication configured for provider {provider.value}") class InvalidToolError(ValidationError): diff --git a/tests/unit_tests/test_auth.py b/tests/unit_tests/test_auth.py index bde9108..0714bc5 100644 --- a/tests/unit_tests/test_auth.py +++ b/tests/unit_tests/test_auth.py @@ -19,7 +19,7 @@ register_auth, resolve_authentication, ) -from celeste.core import Modality, Operation +from celeste.core import Provider from celeste.exceptions import MissingAuthenticationError @@ -112,31 +112,31 @@ def test_get_auth_class_with_unknown_type_raises(self) -> None: @pytest.fixture -def text_auth() -> AuthHeader: - return AuthHeader(secret=SecretStr("text-key")) +def openai_auth() -> AuthHeader: + return AuthHeader(secret=SecretStr("openai-key")) @pytest.fixture -def images_auth() -> AuthHeader: - return AuthHeader(secret=SecretStr("images-key")) +def anthropic_auth() -> AuthHeader: + return AuthHeader(secret=SecretStr("anthropic-key")) @pytest.fixture -def audio_auth() -> NoAuth: +def elevenlabs_auth() -> NoAuth: return NoAuth() @pytest.fixture def full_context( - text_auth: AuthHeader, - images_auth: AuthHeader, - audio_auth: NoAuth, + openai_auth: AuthHeader, + anthropic_auth: AuthHeader, + elevenlabs_auth: NoAuth, ) -> AuthenticationContext: return AuthenticationContext( entries={ - (Modality.TEXT, Operation.GENERATE): text_auth, - (Modality.IMAGES, Operation.GENERATE): images_auth, - (Modality.AUDIO, Operation.SPEAK): audio_auth, + Provider.OPENAI: openai_auth, + Provider.ANTHROPIC: anthropic_auth, + Provider.ELEVENLABS: elevenlabs_auth, } ) @@ -153,88 +153,60 @@ class TestAuthenticationScope: """Test authentication_scope context manager.""" def test_outside_scope_resolves_none(self) -> None: - assert resolve_authentication(Modality.TEXT, Operation.GENERATE) is None + assert resolve_authentication(Provider.OPENAI) is None def test_inside_scope_resolves_bound_auth( - self, full_context: AuthenticationContext, text_auth: AuthHeader + self, full_context: AuthenticationContext, openai_auth: AuthHeader ) -> None: with authentication_scope(full_context): - assert ( - resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth - ) + assert resolve_authentication(Provider.OPENAI) is openai_auth def test_exit_restores_previous_state( - self, full_context: AuthenticationContext, text_auth: AuthHeader + self, full_context: AuthenticationContext, openai_auth: AuthHeader ) -> None: - assert resolve_authentication(Modality.TEXT, Operation.GENERATE) is None + assert resolve_authentication(Provider.OPENAI) is None with authentication_scope(full_context): - assert ( - resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth - ) - assert resolve_authentication(Modality.TEXT, Operation.GENERATE) is None + assert resolve_authentication(Provider.OPENAI) is openai_auth + assert resolve_authentication(Provider.OPENAI) is None def test_nested_scopes_inner_wins_outer_restored( self, full_context: AuthenticationContext, - text_auth: AuthHeader, + openai_auth: AuthHeader, ) -> None: inner_auth = AuthHeader(secret=SecretStr("inner-key")) - inner_context = AuthenticationContext( - entries={(Modality.TEXT, Operation.GENERATE): inner_auth} - ) + inner_context = AuthenticationContext(entries={Provider.OPENAI: inner_auth}) with authentication_scope(full_context): - assert ( - resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth - ) + assert resolve_authentication(Provider.OPENAI) is openai_auth with authentication_scope(inner_context): - assert ( - resolve_authentication(Modality.TEXT, Operation.GENERATE) - is inner_auth - ) - assert ( - resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth - ) + assert resolve_authentication(Provider.OPENAI) is inner_auth + assert resolve_authentication(Provider.OPENAI) is openai_auth def test_scope_with_none_clears_binding( - self, full_context: AuthenticationContext, text_auth: AuthHeader + self, full_context: AuthenticationContext, openai_auth: AuthHeader ) -> None: with authentication_scope(full_context): - assert ( - resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth - ) + assert resolve_authentication(Provider.OPENAI) is openai_auth with authentication_scope(None): - assert resolve_authentication(Modality.TEXT, Operation.GENERATE) is None - assert ( - resolve_authentication(Modality.TEXT, Operation.GENERATE) is text_auth - ) + assert resolve_authentication(Provider.OPENAI) is None + assert resolve_authentication(Provider.OPENAI) is openai_auth class TestResolveAuthentication: """Test resolve_authentication lookup helper.""" - def test_scope_bound_but_entry_missing_raises( + def test_scope_bound_but_provider_missing_raises( self, full_context: AuthenticationContext ) -> None: with ( authentication_scope(full_context), pytest.raises(MissingAuthenticationError) as exc_info, ): - resolve_authentication(Modality.VIDEOS, Operation.GENERATE) + resolve_authentication(Provider.GOOGLE) err = exc_info.value - assert err.modality is Modality.VIDEOS - assert err.operation is Operation.GENERATE - assert "videos/generate" in str(err) - - def test_scope_bound_with_explicit_none_raises(self) -> None: - context = AuthenticationContext( - entries={(Modality.TEXT, Operation.GENERATE): None} - ) - with ( - authentication_scope(context), - pytest.raises(MissingAuthenticationError), - ): - resolve_authentication(Modality.TEXT, Operation.GENERATE) + assert err.provider is Provider.GOOGLE + assert "google" in str(err) class TestAsyncPropagation: @@ -242,46 +214,46 @@ class TestAsyncPropagation: @pytest.mark.asyncio async def test_create_task_propagates( - self, full_context: AuthenticationContext, images_auth: AuthHeader + self, full_context: AuthenticationContext, anthropic_auth: AuthHeader ) -> None: async def child() -> object: - return resolve_authentication(Modality.IMAGES, Operation.GENERATE) + return resolve_authentication(Provider.ANTHROPIC) with authentication_scope(full_context): task = asyncio.create_task(child()) result = await task - assert result is images_auth + assert result is anthropic_auth @pytest.mark.asyncio async def test_gather_siblings_share_snapshot( - self, full_context: AuthenticationContext, images_auth: AuthHeader + self, full_context: AuthenticationContext, anthropic_auth: AuthHeader ) -> None: async def child() -> object: - return resolve_authentication(Modality.IMAGES, Operation.GENERATE) + return resolve_authentication(Provider.ANTHROPIC) with authentication_scope(full_context): results = await asyncio.gather(child(), child(), child()) - assert all(r is images_auth for r in results) + assert all(r is anthropic_auth for r in results) @pytest.mark.asyncio async def test_to_thread_propagates( - self, full_context: AuthenticationContext, text_auth: AuthHeader + self, full_context: AuthenticationContext, openai_auth: AuthHeader ) -> None: def sync_worker() -> object: - return resolve_authentication(Modality.TEXT, Operation.GENERATE) + return resolve_authentication(Provider.OPENAI) with authentication_scope(full_context): result = await asyncio.to_thread(sync_worker) - assert result is text_auth + assert result is openai_auth def test_raw_thread_pool_does_not_propagate( self, full_context: AuthenticationContext ) -> None: def sync_worker() -> object: - return resolve_authentication(Modality.TEXT, Operation.GENERATE) + return resolve_authentication(Provider.OPENAI) with ( ThreadPoolExecutor(max_workers=1) as pool, @@ -293,10 +265,10 @@ def sync_worker() -> object: assert result is None def test_raw_thread_pool_with_copy_context_propagates( - self, full_context: AuthenticationContext, text_auth: AuthHeader + self, full_context: AuthenticationContext, openai_auth: AuthHeader ) -> None: def sync_worker() -> object: - return resolve_authentication(Modality.TEXT, Operation.GENERATE) + return resolve_authentication(Provider.OPENAI) with ( ThreadPoolExecutor(max_workers=1) as pool, @@ -306,4 +278,4 @@ def sync_worker() -> object: future = pool.submit(ctx_snapshot.run, sync_worker) result = future.result() - assert result is text_auth + assert result is openai_auth