From dd4cea531fc0d2dd0c0c9c7015d3dd5b1de47fa1 Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Wed, 15 Apr 2026 12:25:25 +0200 Subject: [PATCH] revert: AuthenticationContext ambient primitive (#259) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit ff60c1989212e573be2dcecc91ea2107a2c36943. Reverting per YAGNI + design error found in downstream analysis. The shipped AuthenticationContext was keyed per-Provider, assuming a single Authentication is bound to a provider. Downstream analysis in a multi-modal, multi-provider context revealed this collapses cases where a single provider legitimately exposes DIFFERENT auth for different (modality, operation) pairs — e.g. OAuth for one operation and API-key for another. A per-Provider AuthenticationContext cannot represent both in the same scope and silently drops one. Independent of that correctness gap, the primitive was not actually needed: primitive peer SDKs (openai-python, anthropic-sdk-python, google-genai) have no ambient auth state and handle multi-tenancy via per-request client construction. celeste-python's own create_client() already supports this natively — it returns a ModalityClient with the auth bound at construction, which is the primitive-tier "bind once, call many" pattern. An ambient ContextVar-backed primitive is a framework-tier idea grafted onto the primitive layer. No downstream consumer actually adopted the primitive, so the revert has zero external breakage. create_client(auth=...) remains the supported path for reusable, auth-bound clients. --- src/celeste/__init__.py | 24 +---- src/celeste/auth.py | 70 +------------ src/celeste/exceptions.py | 11 -- tests/unit_tests/test_auth.py | 183 +--------------------------------- 4 files changed, 4 insertions(+), 284 deletions(-) diff --git a/src/celeste/__init__.py b/src/celeste/__init__.py index 5663e1cf..d2e78ef2 100644 --- a/src/celeste/__init__.py +++ b/src/celeste/__init__.py @@ -6,15 +6,7 @@ from pydantic import SecretStr from celeste import providers as _providers # noqa: F401 -from celeste.auth import ( - APIKey, - Authentication, - AuthenticationContext, - AuthHeader, - NoAuth, - authentication_scope, - resolve_authentication, -) +from celeste.auth import APIKey, Authentication, AuthHeader, NoAuth from celeste.client import ModalityClient from celeste.core import ( Capability, @@ -27,7 +19,6 @@ from celeste.exceptions import ( ClientNotFoundError, Error, - MissingAuthenticationError, ModelNotFoundError, ) from celeste.io import Input, Output, Usage @@ -196,9 +187,6 @@ def create_client( model: Model object, string model ID, or None for auto-selection. api_key: Optional API key override (string or SecretStr). auth: Optional Authentication object for custom auth (e.g., GoogleADC). - When None and api_key is also None, falls back to the ambient - AuthenticationContext bound by ``authentication_scope(...)`` for - the resolved model's provider, before the env-credential path. protocol: Wire format protocol for compatible APIs (e.g., "openresponses", "chatcompletions"). Use with base_url for third-party compatible APIs. base_url: Custom base URL override. Use with protocol for compatible APIs, @@ -211,8 +199,6 @@ def create_client( ModelNotFoundError: If no model found for the specified capability/provider. ClientNotFoundError: If no client registered for capability/provider/protocol. MissingCredentialsError: If required credentials are not configured. - MissingAuthenticationError: If an ambient AuthenticationContext is bound - but has no entry for the resolved model's provider. ValueError: If capability/operation cannot be inferred from model. """ # Translation layer: convert deprecated capability to modality/operation @@ -262,11 +248,6 @@ def create_client( raise ClientNotFoundError(modality=resolved_modality, provider=target) modality_client_class = _CLIENT_MAP[(resolved_modality, target)] - # Ambient fallback: only when neither auth nor api_key was passed and the - # provider is known. Explicit kwargs always win. - if auth is None and api_key is None and resolved_provider is not None: - auth = resolve_authentication(resolved_provider) - # Auth resolution: BYOA for protocol path, credentials for provider path if resolved_protocol is not None and resolved_provider is None: if auth is not None: @@ -295,14 +276,12 @@ def create_client( __all__ = [ "APIKey", "Authentication", - "AuthenticationContext", "Capability", "CodeExecution", "Content", "Error", "Input", "Message", - "MissingAuthenticationError", "Modality", "Model", "Operation", @@ -318,7 +297,6 @@ def create_client( "WebSearch", "XSearch", "audio", - "authentication_scope", "create_client", "documents", "get_model", diff --git a/src/celeste/auth.py b/src/celeste/auth.py index 18f82081..ce127bc3 100644 --- a/src/celeste/auth.py +++ b/src/celeste/auth.py @@ -1,14 +1,8 @@ """Authentication methods for Celeste providers.""" from abc import ABC, abstractmethod -from collections.abc import Iterator, Mapping -from contextlib import contextmanager -from contextvars import ContextVar -from pydantic import BaseModel, ConfigDict, SecretStr, field_validator - -from celeste.core import Provider -from celeste.exceptions import MissingAuthenticationError +from pydantic import BaseModel, SecretStr, field_validator # Module-level registry (same pattern as _clients and _models) _auth_classes: dict[str, type["Authentication"]] = {} @@ -92,73 +86,11 @@ def get_auth_class(auth_type: str) -> type[Authentication]: return _auth_classes[auth_type] -class AuthenticationContext(BaseModel): - """Per-provider authentication bound to an async scope.""" - - # Frozen: asyncio.gather siblings share the context snapshot by reference, - # so mutability would cause silent cross-task credential bleed. - model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) - - entries: Mapping[Provider, Authentication] - - -_current_context: ContextVar[AuthenticationContext | None] = ContextVar( - "celeste.auth.current_context", - default=None, -) - - -@contextmanager -def authentication_scope( - context: AuthenticationContext | None, -) -> Iterator[None]: - """Bind an authentication context for the current async scope. - - Within the ``with`` block, calls to ``celeste..(...)`` - that don't pass an explicit ``auth=`` or ``api_key=`` resolve their - authentication from the bound context using the resolved model's provider. - - Args: - context: The AuthenticationContext to bind, or None to clear any - outer binding for the duration of the block. - """ - token = _current_context.set(context) - try: - yield - finally: - _current_context.reset(token) - - -def resolve_authentication(provider: Provider) -> Authentication | None: - """Look up the provider in the current ambient context. - - Args: - provider: The provider to look up. - - Returns: - The bound Authentication, or None if no ambient context is active. - - Raises: - MissingAuthenticationError: A scope is bound but has no authentication - for the requested provider. - """ - context = _current_context.get() - if context is None: - return None - auth = context.entries.get(provider) - if auth is None: - raise MissingAuthenticationError(provider) - return auth - - __all__ = [ "APIKey", "AuthHeader", "Authentication", - "AuthenticationContext", "NoAuth", - "authentication_scope", "get_auth_class", "register_auth", - "resolve_authentication", ] diff --git a/src/celeste/exceptions.py b/src/celeste/exceptions.py index 9b11c375..83f74ac8 100644 --- a/src/celeste/exceptions.py +++ b/src/celeste/exceptions.py @@ -2,8 +2,6 @@ from typing import Any -from celeste.core import Provider - class Error(Exception): """Base exception for all Celeste errors.""" @@ -235,14 +233,6 @@ def __init__(self, provider: str) -> None: ) -class MissingAuthenticationError(CredentialsError): - """Raised when authentication cannot be resolved for a provider.""" - - def __init__(self, provider: Provider) -> None: - self.provider = provider - super().__init__(f"No authentication configured for provider {provider.value}") - - class InvalidToolError(ValidationError): """Raised when a tool item is not a Tool instance or dict.""" @@ -273,7 +263,6 @@ class UnsupportedParameterWarning(UserWarning): "ConstraintViolationError", "Error", "InvalidToolError", - "MissingAuthenticationError", "MissingCredentialsError", "MissingDependencyError", "ModalityNotFoundError", diff --git a/tests/unit_tests/test_auth.py b/tests/unit_tests/test_auth.py index 0714bc58..3212e723 100644 --- a/tests/unit_tests/test_auth.py +++ b/tests/unit_tests/test_auth.py @@ -1,26 +1,17 @@ -"""Tests for authentication classes, registry, and ambient scope.""" +"""Tests for authentication classes and registry.""" -import asyncio -import contextvars from collections.abc import Generator -from concurrent.futures import ThreadPoolExecutor import pytest -from pydantic import SecretStr, ValidationError +from pydantic import SecretStr from celeste.auth import ( APIKey, Authentication, - AuthenticationContext, AuthHeader, - NoAuth, - authentication_scope, get_auth_class, register_auth, - resolve_authentication, ) -from celeste.core import Provider -from celeste.exceptions import MissingAuthenticationError @pytest.fixture(autouse=True) @@ -109,173 +100,3 @@ def test_get_auth_class_with_unknown_type_raises(self) -> None: ValueError, match=r"Unknown auth type: nonexistent.*Available:" ): get_auth_class("nonexistent") - - -@pytest.fixture -def openai_auth() -> AuthHeader: - return AuthHeader(secret=SecretStr("openai-key")) - - -@pytest.fixture -def anthropic_auth() -> AuthHeader: - return AuthHeader(secret=SecretStr("anthropic-key")) - - -@pytest.fixture -def elevenlabs_auth() -> NoAuth: - return NoAuth() - - -@pytest.fixture -def full_context( - openai_auth: AuthHeader, - anthropic_auth: AuthHeader, - elevenlabs_auth: NoAuth, -) -> AuthenticationContext: - return AuthenticationContext( - entries={ - Provider.OPENAI: openai_auth, - Provider.ANTHROPIC: anthropic_auth, - Provider.ELEVENLABS: elevenlabs_auth, - } - ) - - -class TestAuthenticationContext: - """Test AuthenticationContext pydantic model.""" - - def test_frozen_rejects_mutation(self, full_context: AuthenticationContext) -> None: - with pytest.raises(ValidationError): - full_context.entries = {} # type: ignore[misc] - - -class TestAuthenticationScope: - """Test authentication_scope context manager.""" - - def test_outside_scope_resolves_none(self) -> None: - assert resolve_authentication(Provider.OPENAI) is None - - def test_inside_scope_resolves_bound_auth( - self, full_context: AuthenticationContext, openai_auth: AuthHeader - ) -> None: - with authentication_scope(full_context): - assert resolve_authentication(Provider.OPENAI) is openai_auth - - def test_exit_restores_previous_state( - self, full_context: AuthenticationContext, openai_auth: AuthHeader - ) -> None: - assert resolve_authentication(Provider.OPENAI) is None - with authentication_scope(full_context): - assert resolve_authentication(Provider.OPENAI) is openai_auth - assert resolve_authentication(Provider.OPENAI) is None - - def test_nested_scopes_inner_wins_outer_restored( - self, - full_context: AuthenticationContext, - openai_auth: AuthHeader, - ) -> None: - inner_auth = AuthHeader(secret=SecretStr("inner-key")) - inner_context = AuthenticationContext(entries={Provider.OPENAI: inner_auth}) - - with authentication_scope(full_context): - assert resolve_authentication(Provider.OPENAI) is openai_auth - with authentication_scope(inner_context): - assert resolve_authentication(Provider.OPENAI) is inner_auth - assert resolve_authentication(Provider.OPENAI) is openai_auth - - def test_scope_with_none_clears_binding( - self, full_context: AuthenticationContext, openai_auth: AuthHeader - ) -> None: - with authentication_scope(full_context): - assert resolve_authentication(Provider.OPENAI) is openai_auth - with authentication_scope(None): - assert resolve_authentication(Provider.OPENAI) is None - assert resolve_authentication(Provider.OPENAI) is openai_auth - - -class TestResolveAuthentication: - """Test resolve_authentication lookup helper.""" - - def test_scope_bound_but_provider_missing_raises( - self, full_context: AuthenticationContext - ) -> None: - with ( - authentication_scope(full_context), - pytest.raises(MissingAuthenticationError) as exc_info, - ): - resolve_authentication(Provider.GOOGLE) - err = exc_info.value - assert err.provider is Provider.GOOGLE - assert "google" in str(err) - - -class TestAsyncPropagation: - """Test ContextVar propagation across async boundaries.""" - - @pytest.mark.asyncio - async def test_create_task_propagates( - self, full_context: AuthenticationContext, anthropic_auth: AuthHeader - ) -> None: - async def child() -> object: - return resolve_authentication(Provider.ANTHROPIC) - - with authentication_scope(full_context): - task = asyncio.create_task(child()) - result = await task - - assert result is anthropic_auth - - @pytest.mark.asyncio - async def test_gather_siblings_share_snapshot( - self, full_context: AuthenticationContext, anthropic_auth: AuthHeader - ) -> None: - async def child() -> object: - return resolve_authentication(Provider.ANTHROPIC) - - with authentication_scope(full_context): - results = await asyncio.gather(child(), child(), child()) - - assert all(r is anthropic_auth for r in results) - - @pytest.mark.asyncio - async def test_to_thread_propagates( - self, full_context: AuthenticationContext, openai_auth: AuthHeader - ) -> None: - def sync_worker() -> object: - return resolve_authentication(Provider.OPENAI) - - with authentication_scope(full_context): - result = await asyncio.to_thread(sync_worker) - - assert result is openai_auth - - def test_raw_thread_pool_does_not_propagate( - self, full_context: AuthenticationContext - ) -> None: - def sync_worker() -> object: - return resolve_authentication(Provider.OPENAI) - - with ( - ThreadPoolExecutor(max_workers=1) as pool, - authentication_scope(full_context), - ): - future = pool.submit(sync_worker) - result = future.result() - - assert result is None - - def test_raw_thread_pool_with_copy_context_propagates( - self, full_context: AuthenticationContext, openai_auth: AuthHeader - ) -> None: - def sync_worker() -> object: - return resolve_authentication(Provider.OPENAI) - - with ( - ThreadPoolExecutor(max_workers=1) as pool, - authentication_scope(full_context), - ): - ctx_snapshot = contextvars.copy_context() - future = pool.submit(ctx_snapshot.run, sync_worker) - result = future.result() - - assert result is openai_auth