Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion integrations/openai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pip install git+https://git@github.com/databricks/databricks-ai-bridge.git#subdi
## Key Features

- **Vector Search:** Store and query vector representations using `VectorSearchRetrieverTool`.
- **OpenAI-compatible clients:** Use Databricks authentication with OpenAI SDK resources,
including optional separate routing for Conversations API calls.

## Getting Started

Expand Down Expand Up @@ -56,6 +58,31 @@ second_response = client.chat.completions.create(
)
```

### Use Conversations API state

Conversations API calls default to `{workspace_url}/api/2.1/unity-catalog`, even when
Responses API calls use another Databricks OpenAI-compatible base URL such as AI Gateway.
Use `conversations_base_url` only when the Conversations API is served from a custom path.

```python
from databricks.sdk import WorkspaceClient
from databricks_openai import DatabricksOpenAI

workspace_client = WorkspaceClient()

client = DatabricksOpenAI(
workspace_client=workspace_client,
use_ai_gateway=True,
)

conversation = client.conversations.create()
response = client.responses.create(
model="databricks-claude-sonnet-4-5",
conversation=conversation.id,
input="Tell me about Databricks",
)
```

---

## Contribution Guide
Expand All @@ -65,4 +92,3 @@ We welcome contributions! Please see our [contribution guidelines](https://githu
This project is licensed under the [MIT License](LICENSE).

Thank you for using Databricks OpenAI!

50 changes: 50 additions & 0 deletions integrations/openai/src/databricks_openai/utils/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from openai import APIConnectionError, APIStatusError, AsyncOpenAI, OpenAI
from openai.resources.chat import AsyncChat, Chat
from openai.resources.chat.completions import AsyncCompletions, Completions
from openai.resources.conversations import AsyncConversations, Conversations
from openai.resources.responses import AsyncResponses, Responses
from typing_extensions import override

Expand Down Expand Up @@ -137,6 +138,15 @@ def _resolve_base_url(
return f"{host}/serving-endpoints"


def _resolve_conversations_base_url(
workspace_client: WorkspaceClient,
conversations_base_url: str | None,
) -> str:
if conversations_base_url is not None:
return conversations_base_url
return f"{workspace_client.config.host}/api/2.1/unity-catalog"


def _get_authorized_http_client(workspace_client: WorkspaceClient) -> Client:
databricks_token_auth = BearerAuth(workspace_client.config.authenticate)
return Client(auth=databricks_token_auth)
Expand Down Expand Up @@ -337,6 +347,8 @@ class DatabricksOpenAI(OpenAI):
base_url: Optional base URL to override the default serving endpoints URL. When the URL
points to a Databricks App (contains "databricksapps"), OAuth authentication is
required.
conversations_base_url: Optional base URL to use for OpenAI Conversations API calls.
Defaults to ``{workspace_url}/api/2.1/unity-catalog``.
use_ai_gateway_native_api: If True, auto-detect AI Gateway V2 and route requests through
its native OpenAI-compatible API (``<ai_gateway_url>/openai/v1``). This allows use of
provider-native features not available through the MLflow API. Cannot be combined
Expand Down Expand Up @@ -385,6 +397,7 @@ def __init__(
self,
workspace_client: WorkspaceClient | None = None,
base_url: str | None = None,
conversations_base_url: str | None = None,
use_ai_gateway_native_api: bool = False,
use_ai_gateway: bool = False,
**kwargs: Any,
Expand All @@ -393,6 +406,9 @@ def __init__(
workspace_client = WorkspaceClient()

self._workspace_client = workspace_client
self._conversations_base_url = _resolve_conversations_base_url(
workspace_client, conversations_base_url
)

target_base_url = _resolve_base_url(
workspace_client, base_url, use_ai_gateway, use_ai_gateway_native_api
Expand Down Expand Up @@ -425,6 +441,20 @@ def responses(self) -> Responses:
self._databricks_responses = DatabricksResponses(self, self._workspace_client)
return self._databricks_responses

@property
def conversations(self) -> Conversations:
if not hasattr(self, "_databricks_conversations_client"):
self._databricks_conversations_client = OpenAI(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should use DatabricksOpenAI here right so workspace client authorization flows through?

base_url=self._conversations_base_url,
api_key=self.api_key,
http_client=self._client,
timeout=self.timeout,
max_retries=self.max_retries,
default_headers=self._custom_headers,
default_query=self._custom_query,
)
return self._databricks_conversations_client.conversations


class AsyncDatabricksCompletions(AsyncCompletions):
"""Async completions that conditionally strips 'strict' from tools for non-GPT models."""
Expand Down Expand Up @@ -502,6 +532,8 @@ class AsyncDatabricksOpenAI(AsyncOpenAI):
base_url: Optional base URL to override the default serving endpoints URL. When the URL
points to a Databricks App (contains "databricksapps"), OAuth authentication is
required.
conversations_base_url: Optional base URL to use for OpenAI Conversations API calls.
Defaults to ``{workspace_url}/api/2.1/unity-catalog``.
use_ai_gateway_native_api: If True, auto-detect AI Gateway V2 and route requests through
its native OpenAI-compatible API (``<ai_gateway_url>/openai/v1``). This allows use of
provider-native features not available through the MLflow API. Cannot be combined
Expand Down Expand Up @@ -550,6 +582,7 @@ def __init__(
self,
workspace_client: WorkspaceClient | None = None,
base_url: str | None = None,
conversations_base_url: str | None = None,
use_ai_gateway_native_api: bool = False,
use_ai_gateway: bool = False,
**kwargs: Any,
Expand All @@ -558,6 +591,9 @@ def __init__(
workspace_client = WorkspaceClient()

self._workspace_client = workspace_client
self._conversations_base_url = _resolve_conversations_base_url(
workspace_client, conversations_base_url
)

target_base_url = _resolve_base_url(
workspace_client, base_url, use_ai_gateway, use_ai_gateway_native_api
Expand Down Expand Up @@ -588,3 +624,17 @@ def responses(self) -> AsyncResponses:
if not hasattr(self, "_databricks_responses"):
self._databricks_responses = AsyncDatabricksResponses(self, self._workspace_client)
return self._databricks_responses

@property
def conversations(self) -> AsyncConversations:
if not hasattr(self, "_databricks_conversations_client"):
self._databricks_conversations_client = AsyncOpenAI(
base_url=self._conversations_base_url,
api_key=self.api_key,
http_client=self._client,
timeout=self.timeout,
max_retries=self.max_retries,
default_headers=self._custom_headers,
default_query=self._custom_query,
)
return self._databricks_conversations_client.conversations
43 changes: 43 additions & 0 deletions integrations/openai/tests/unit_tests/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from openai import APIConnectionError, APIStatusError, AsyncOpenAI, OpenAI
from openai._types import NOT_GIVEN, Omit
from openai.resources.chat.completions import AsyncCompletions, Completions
from openai.resources.conversations import AsyncConversations, Conversations
from openai.resources.responses import AsyncResponses, Responses

from databricks_openai import AsyncDatabricksOpenAI, DatabricksOpenAI
Expand Down Expand Up @@ -546,6 +547,48 @@ def test_init_with_non_databricksapps_base_url_does_not_require_oauth(
mock_workspace_client_no_oauth.config.oauth_token.assert_not_called()


class TestConversationsBaseUrl:
"""Tests for routing OpenAI Conversations API calls to a separate base URL."""

def test_sync_conversations_use_default_unity_catalog_base_url(self, mock_workspace_client):
client = DatabricksOpenAI(workspace_client=mock_workspace_client, use_ai_gateway=True)

assert isinstance(client.conversations, Conversations)
assert "/ai-gateway/mlflow/v1/" in str(client.base_url)
assert "/api/2.1/unity-catalog/" in str(client.conversations._client.base_url)

def test_sync_conversations_use_override_base_url(self, mock_workspace_client):
client = DatabricksOpenAI(
workspace_client=mock_workspace_client,
use_ai_gateway=True,
conversations_base_url="https://test.databricks.com/serving-endpoints",
)

assert "/ai-gateway/mlflow/v1/" in str(client.base_url)
assert "/serving-endpoints/" in str(client.conversations._client.base_url)
assert client.conversations is client.conversations

def test_async_conversations_use_default_unity_catalog_base_url(self, mock_workspace_client):
client = AsyncDatabricksOpenAI(
workspace_client=mock_workspace_client, use_ai_gateway=True
)

assert isinstance(client.conversations, AsyncConversations)
assert "/ai-gateway/mlflow/v1/" in str(client.base_url)
assert "/api/2.1/unity-catalog/" in str(client.conversations._client.base_url)

def test_async_conversations_use_override_base_url(self, mock_workspace_client):
client = AsyncDatabricksOpenAI(
workspace_client=mock_workspace_client,
use_ai_gateway=True,
conversations_base_url="https://test.databricks.com/serving-endpoints",
)

assert "/ai-gateway/mlflow/v1/" in str(client.base_url)
assert "/serving-endpoints/" in str(client.conversations._client.base_url)
assert client.conversations is client.conversations


class TestAppsRouting:
"""Tests for apps/ prefix routing in DatabricksOpenAI and AsyncDatabricksOpenAI."""

Expand Down
Loading