Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 1 addition & 23 deletions src/celeste/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,7 @@
from pydantic import SecretStr

from celeste import providers as _providers # noqa: F401
from celeste.auth import (
APIKey,
Authentication,
AuthenticationContext,
AuthHeader,
NoAuth,
authentication_scope,
resolve_authentication,
)
from celeste.auth import APIKey, Authentication, AuthHeader, NoAuth
from celeste.client import ModalityClient
from celeste.core import (
Capability,
Expand All @@ -27,7 +19,6 @@
from celeste.exceptions import (
ClientNotFoundError,
Error,
MissingAuthenticationError,
ModelNotFoundError,
)
from celeste.io import Input, Output, Usage
Expand Down Expand Up @@ -196,9 +187,6 @@ def create_client(
model: Model object, string model ID, or None for auto-selection.
api_key: Optional API key override (string or SecretStr).
auth: Optional Authentication object for custom auth (e.g., GoogleADC).
When None and api_key is also None, falls back to the ambient
AuthenticationContext bound by ``authentication_scope(...)`` for
the resolved model's provider, before the env-credential path.
protocol: Wire format protocol for compatible APIs (e.g., "openresponses",
"chatcompletions"). Use with base_url for third-party compatible APIs.
base_url: Custom base URL override. Use with protocol for compatible APIs,
Expand All @@ -211,8 +199,6 @@ def create_client(
ModelNotFoundError: If no model found for the specified capability/provider.
ClientNotFoundError: If no client registered for capability/provider/protocol.
MissingCredentialsError: If required credentials are not configured.
MissingAuthenticationError: If an ambient AuthenticationContext is bound
but has no entry for the resolved model's provider.
ValueError: If capability/operation cannot be inferred from model.
"""
# Translation layer: convert deprecated capability to modality/operation
Expand Down Expand Up @@ -262,11 +248,6 @@ def create_client(
raise ClientNotFoundError(modality=resolved_modality, provider=target)
modality_client_class = _CLIENT_MAP[(resolved_modality, target)]

# Ambient fallback: only when neither auth nor api_key was passed and the
# provider is known. Explicit kwargs always win.
if auth is None and api_key is None and resolved_provider is not None:
auth = resolve_authentication(resolved_provider)

# Auth resolution: BYOA for protocol path, credentials for provider path
if resolved_protocol is not None and resolved_provider is None:
if auth is not None:
Expand Down Expand Up @@ -295,14 +276,12 @@ def create_client(
__all__ = [
"APIKey",
"Authentication",
"AuthenticationContext",
"Capability",
"CodeExecution",
"Content",
"Error",
"Input",
"Message",
"MissingAuthenticationError",
"Modality",
"Model",
"Operation",
Expand All @@ -318,7 +297,6 @@ def create_client(
"WebSearch",
"XSearch",
"audio",
"authentication_scope",
"create_client",
"documents",
"get_model",
Expand Down
70 changes: 1 addition & 69 deletions src/celeste/auth.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
"""Authentication methods for Celeste providers."""

from abc import ABC, abstractmethod
from collections.abc import Iterator, Mapping
from contextlib import contextmanager
from contextvars import ContextVar

from pydantic import BaseModel, ConfigDict, SecretStr, field_validator

from celeste.core import Provider
from celeste.exceptions import MissingAuthenticationError
from pydantic import BaseModel, SecretStr, field_validator

# Module-level registry (same pattern as _clients and _models)
_auth_classes: dict[str, type["Authentication"]] = {}
Expand Down Expand Up @@ -92,73 +86,11 @@ def get_auth_class(auth_type: str) -> type[Authentication]:
return _auth_classes[auth_type]


class AuthenticationContext(BaseModel):
"""Per-provider authentication bound to an async scope."""

# Frozen: asyncio.gather siblings share the context snapshot by reference,
# so mutability would cause silent cross-task credential bleed.
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)

entries: Mapping[Provider, Authentication]


_current_context: ContextVar[AuthenticationContext | None] = ContextVar(
"celeste.auth.current_context",
default=None,
)


@contextmanager
def authentication_scope(
context: AuthenticationContext | None,
) -> Iterator[None]:
"""Bind an authentication context for the current async scope.

Within the ``with`` block, calls to ``celeste.<modality>.<method>(...)``
that don't pass an explicit ``auth=`` or ``api_key=`` resolve their
authentication from the bound context using the resolved model's provider.

Args:
context: The AuthenticationContext to bind, or None to clear any
outer binding for the duration of the block.
"""
token = _current_context.set(context)
try:
yield
finally:
_current_context.reset(token)


def resolve_authentication(provider: Provider) -> Authentication | None:
"""Look up the provider in the current ambient context.

Args:
provider: The provider to look up.

Returns:
The bound Authentication, or None if no ambient context is active.

Raises:
MissingAuthenticationError: A scope is bound but has no authentication
for the requested provider.
"""
context = _current_context.get()
if context is None:
return None
auth = context.entries.get(provider)
if auth is None:
raise MissingAuthenticationError(provider)
return auth


__all__ = [
"APIKey",
"AuthHeader",
"Authentication",
"AuthenticationContext",
"NoAuth",
"authentication_scope",
"get_auth_class",
"register_auth",
"resolve_authentication",
]
11 changes: 0 additions & 11 deletions src/celeste/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from typing import Any

from celeste.core import Provider


class Error(Exception):
"""Base exception for all Celeste errors."""
Expand Down Expand Up @@ -235,14 +233,6 @@ def __init__(self, provider: str) -> None:
)


class MissingAuthenticationError(CredentialsError):
"""Raised when authentication cannot be resolved for a provider."""

def __init__(self, provider: Provider) -> None:
self.provider = provider
super().__init__(f"No authentication configured for provider {provider.value}")


class InvalidToolError(ValidationError):
"""Raised when a tool item is not a Tool instance or dict."""

Expand Down Expand Up @@ -273,7 +263,6 @@ class UnsupportedParameterWarning(UserWarning):
"ConstraintViolationError",
"Error",
"InvalidToolError",
"MissingAuthenticationError",
"MissingCredentialsError",
"MissingDependencyError",
"ModalityNotFoundError",
Expand Down
Loading
Loading