diff --git a/src/celeste/__init__.py b/src/celeste/__init__.py index d2e78ef..5663e1c 100644 --- a/src/celeste/__init__.py +++ b/src/celeste/__init__.py @@ -6,7 +6,15 @@ from pydantic import SecretStr from celeste import providers as _providers # noqa: F401 -from celeste.auth import APIKey, Authentication, AuthHeader, NoAuth +from celeste.auth import ( + APIKey, + Authentication, + AuthenticationContext, + AuthHeader, + NoAuth, + authentication_scope, + resolve_authentication, +) from celeste.client import ModalityClient from celeste.core import ( Capability, @@ -19,6 +27,7 @@ from celeste.exceptions import ( ClientNotFoundError, Error, + MissingAuthenticationError, ModelNotFoundError, ) from celeste.io import Input, Output, Usage @@ -187,6 +196,9 @@ def create_client( model: Model object, string model ID, or None for auto-selection. api_key: Optional API key override (string or SecretStr). auth: Optional Authentication object for custom auth (e.g., GoogleADC). + When None and api_key is also None, falls back to the ambient + AuthenticationContext bound by ``authentication_scope(...)`` for + the resolved model's provider, before the env-credential path. protocol: Wire format protocol for compatible APIs (e.g., "openresponses", "chatcompletions"). Use with base_url for third-party compatible APIs. base_url: Custom base URL override. Use with protocol for compatible APIs, @@ -199,6 +211,8 @@ def create_client( ModelNotFoundError: If no model found for the specified capability/provider. ClientNotFoundError: If no client registered for capability/provider/protocol. MissingCredentialsError: If required credentials are not configured. + MissingAuthenticationError: If an ambient AuthenticationContext is bound + but has no entry for the resolved model's provider. ValueError: If capability/operation cannot be inferred from model. """ # Translation layer: convert deprecated capability to modality/operation @@ -248,6 +262,11 @@ def create_client( raise ClientNotFoundError(modality=resolved_modality, provider=target) modality_client_class = _CLIENT_MAP[(resolved_modality, target)] + # Ambient fallback: only when neither auth nor api_key was passed and the + # provider is known. Explicit kwargs always win. + if auth is None and api_key is None and resolved_provider is not None: + auth = resolve_authentication(resolved_provider) + # Auth resolution: BYOA for protocol path, credentials for provider path if resolved_protocol is not None and resolved_provider is None: if auth is not None: @@ -276,12 +295,14 @@ def create_client( __all__ = [ "APIKey", "Authentication", + "AuthenticationContext", "Capability", "CodeExecution", "Content", "Error", "Input", "Message", + "MissingAuthenticationError", "Modality", "Model", "Operation", @@ -297,6 +318,7 @@ def create_client( "WebSearch", "XSearch", "audio", + "authentication_scope", "create_client", "documents", "get_model", diff --git a/src/celeste/auth.py b/src/celeste/auth.py index ce127bc..18f8208 100644 --- a/src/celeste/auth.py +++ b/src/celeste/auth.py @@ -1,8 +1,14 @@ """Authentication methods for Celeste providers.""" from abc import ABC, abstractmethod +from collections.abc import Iterator, Mapping +from contextlib import contextmanager +from contextvars import ContextVar -from pydantic import BaseModel, SecretStr, field_validator +from pydantic import BaseModel, ConfigDict, SecretStr, field_validator + +from celeste.core import Provider +from celeste.exceptions import MissingAuthenticationError # Module-level registry (same pattern as _clients and _models) _auth_classes: dict[str, type["Authentication"]] = {} @@ -86,11 +92,73 @@ def get_auth_class(auth_type: str) -> type[Authentication]: return _auth_classes[auth_type] +class AuthenticationContext(BaseModel): + """Per-provider authentication bound to an async scope.""" + + # Frozen: asyncio.gather siblings share the context snapshot by reference, + # so mutability would cause silent cross-task credential bleed. + model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) + + entries: Mapping[Provider, Authentication] + + +_current_context: ContextVar[AuthenticationContext | None] = ContextVar( + "celeste.auth.current_context", + default=None, +) + + +@contextmanager +def authentication_scope( + context: AuthenticationContext | None, +) -> Iterator[None]: + """Bind an authentication context for the current async scope. + + Within the ``with`` block, calls to ``celeste..(...)`` + that don't pass an explicit ``auth=`` or ``api_key=`` resolve their + authentication from the bound context using the resolved model's provider. + + Args: + context: The AuthenticationContext to bind, or None to clear any + outer binding for the duration of the block. + """ + token = _current_context.set(context) + try: + yield + finally: + _current_context.reset(token) + + +def resolve_authentication(provider: Provider) -> Authentication | None: + """Look up the provider in the current ambient context. + + Args: + provider: The provider to look up. + + Returns: + The bound Authentication, or None if no ambient context is active. + + Raises: + MissingAuthenticationError: A scope is bound but has no authentication + for the requested provider. + """ + context = _current_context.get() + if context is None: + return None + auth = context.entries.get(provider) + if auth is None: + raise MissingAuthenticationError(provider) + return auth + + __all__ = [ "APIKey", "AuthHeader", "Authentication", + "AuthenticationContext", "NoAuth", + "authentication_scope", "get_auth_class", "register_auth", + "resolve_authentication", ] diff --git a/src/celeste/exceptions.py b/src/celeste/exceptions.py index 83f74ac..9b11c37 100644 --- a/src/celeste/exceptions.py +++ b/src/celeste/exceptions.py @@ -2,6 +2,8 @@ from typing import Any +from celeste.core import Provider + class Error(Exception): """Base exception for all Celeste errors.""" @@ -233,6 +235,14 @@ def __init__(self, provider: str) -> None: ) +class MissingAuthenticationError(CredentialsError): + """Raised when authentication cannot be resolved for a provider.""" + + def __init__(self, provider: Provider) -> None: + self.provider = provider + super().__init__(f"No authentication configured for provider {provider.value}") + + class InvalidToolError(ValidationError): """Raised when a tool item is not a Tool instance or dict.""" @@ -263,6 +273,7 @@ class UnsupportedParameterWarning(UserWarning): "ConstraintViolationError", "Error", "InvalidToolError", + "MissingAuthenticationError", "MissingCredentialsError", "MissingDependencyError", "ModalityNotFoundError", diff --git a/tests/unit_tests/test_auth.py b/tests/unit_tests/test_auth.py index 3212e72..0714bc5 100644 --- a/tests/unit_tests/test_auth.py +++ b/tests/unit_tests/test_auth.py @@ -1,17 +1,26 @@ -"""Tests for authentication classes and registry.""" +"""Tests for authentication classes, registry, and ambient scope.""" +import asyncio +import contextvars from collections.abc import Generator +from concurrent.futures import ThreadPoolExecutor import pytest -from pydantic import SecretStr +from pydantic import SecretStr, ValidationError from celeste.auth import ( APIKey, Authentication, + AuthenticationContext, AuthHeader, + NoAuth, + authentication_scope, get_auth_class, register_auth, + resolve_authentication, ) +from celeste.core import Provider +from celeste.exceptions import MissingAuthenticationError @pytest.fixture(autouse=True) @@ -100,3 +109,173 @@ def test_get_auth_class_with_unknown_type_raises(self) -> None: ValueError, match=r"Unknown auth type: nonexistent.*Available:" ): get_auth_class("nonexistent") + + +@pytest.fixture +def openai_auth() -> AuthHeader: + return AuthHeader(secret=SecretStr("openai-key")) + + +@pytest.fixture +def anthropic_auth() -> AuthHeader: + return AuthHeader(secret=SecretStr("anthropic-key")) + + +@pytest.fixture +def elevenlabs_auth() -> NoAuth: + return NoAuth() + + +@pytest.fixture +def full_context( + openai_auth: AuthHeader, + anthropic_auth: AuthHeader, + elevenlabs_auth: NoAuth, +) -> AuthenticationContext: + return AuthenticationContext( + entries={ + Provider.OPENAI: openai_auth, + Provider.ANTHROPIC: anthropic_auth, + Provider.ELEVENLABS: elevenlabs_auth, + } + ) + + +class TestAuthenticationContext: + """Test AuthenticationContext pydantic model.""" + + def test_frozen_rejects_mutation(self, full_context: AuthenticationContext) -> None: + with pytest.raises(ValidationError): + full_context.entries = {} # type: ignore[misc] + + +class TestAuthenticationScope: + """Test authentication_scope context manager.""" + + def test_outside_scope_resolves_none(self) -> None: + assert resolve_authentication(Provider.OPENAI) is None + + def test_inside_scope_resolves_bound_auth( + self, full_context: AuthenticationContext, openai_auth: AuthHeader + ) -> None: + with authentication_scope(full_context): + assert resolve_authentication(Provider.OPENAI) is openai_auth + + def test_exit_restores_previous_state( + self, full_context: AuthenticationContext, openai_auth: AuthHeader + ) -> None: + assert resolve_authentication(Provider.OPENAI) is None + with authentication_scope(full_context): + assert resolve_authentication(Provider.OPENAI) is openai_auth + assert resolve_authentication(Provider.OPENAI) is None + + def test_nested_scopes_inner_wins_outer_restored( + self, + full_context: AuthenticationContext, + openai_auth: AuthHeader, + ) -> None: + inner_auth = AuthHeader(secret=SecretStr("inner-key")) + inner_context = AuthenticationContext(entries={Provider.OPENAI: inner_auth}) + + with authentication_scope(full_context): + assert resolve_authentication(Provider.OPENAI) is openai_auth + with authentication_scope(inner_context): + assert resolve_authentication(Provider.OPENAI) is inner_auth + assert resolve_authentication(Provider.OPENAI) is openai_auth + + def test_scope_with_none_clears_binding( + self, full_context: AuthenticationContext, openai_auth: AuthHeader + ) -> None: + with authentication_scope(full_context): + assert resolve_authentication(Provider.OPENAI) is openai_auth + with authentication_scope(None): + assert resolve_authentication(Provider.OPENAI) is None + assert resolve_authentication(Provider.OPENAI) is openai_auth + + +class TestResolveAuthentication: + """Test resolve_authentication lookup helper.""" + + def test_scope_bound_but_provider_missing_raises( + self, full_context: AuthenticationContext + ) -> None: + with ( + authentication_scope(full_context), + pytest.raises(MissingAuthenticationError) as exc_info, + ): + resolve_authentication(Provider.GOOGLE) + err = exc_info.value + assert err.provider is Provider.GOOGLE + assert "google" in str(err) + + +class TestAsyncPropagation: + """Test ContextVar propagation across async boundaries.""" + + @pytest.mark.asyncio + async def test_create_task_propagates( + self, full_context: AuthenticationContext, anthropic_auth: AuthHeader + ) -> None: + async def child() -> object: + return resolve_authentication(Provider.ANTHROPIC) + + with authentication_scope(full_context): + task = asyncio.create_task(child()) + result = await task + + assert result is anthropic_auth + + @pytest.mark.asyncio + async def test_gather_siblings_share_snapshot( + self, full_context: AuthenticationContext, anthropic_auth: AuthHeader + ) -> None: + async def child() -> object: + return resolve_authentication(Provider.ANTHROPIC) + + with authentication_scope(full_context): + results = await asyncio.gather(child(), child(), child()) + + assert all(r is anthropic_auth for r in results) + + @pytest.mark.asyncio + async def test_to_thread_propagates( + self, full_context: AuthenticationContext, openai_auth: AuthHeader + ) -> None: + def sync_worker() -> object: + return resolve_authentication(Provider.OPENAI) + + with authentication_scope(full_context): + result = await asyncio.to_thread(sync_worker) + + assert result is openai_auth + + def test_raw_thread_pool_does_not_propagate( + self, full_context: AuthenticationContext + ) -> None: + def sync_worker() -> object: + return resolve_authentication(Provider.OPENAI) + + with ( + ThreadPoolExecutor(max_workers=1) as pool, + authentication_scope(full_context), + ): + future = pool.submit(sync_worker) + result = future.result() + + assert result is None + + def test_raw_thread_pool_with_copy_context_propagates( + self, full_context: AuthenticationContext, openai_auth: AuthHeader + ) -> None: + def sync_worker() -> object: + return resolve_authentication(Provider.OPENAI) + + with ( + ThreadPoolExecutor(max_workers=1) as pool, + authentication_scope(full_context), + ): + ctx_snapshot = contextvars.copy_context() + future = pool.submit(ctx_snapshot.run, sync_worker) + result = future.result() + + assert result is openai_auth