diff --git a/integrations/openai/README.md b/integrations/openai/README.md index 9945591b..3ff6ed97 100644 --- a/integrations/openai/README.md +++ b/integrations/openai/README.md @@ -17,6 +17,8 @@ pip install git+https://git@github.com/databricks/databricks-ai-bridge.git#subdi ## Key Features - **Vector Search:** Store and query vector representations using `VectorSearchRetrieverTool`. +- **OpenAI-compatible clients:** Use Databricks authentication with OpenAI SDK resources, + including optional separate routing for Conversations API calls. ## Getting Started @@ -56,6 +58,31 @@ second_response = client.chat.completions.create( ) ``` +### Use Conversations API state + +Conversations API calls default to `{workspace_url}/api/2.1/unity-catalog`, even when +Responses API calls use another Databricks OpenAI-compatible base URL such as AI Gateway. +Use `conversations_base_url` only when the Conversations API is served from a custom path. + +```python +from databricks.sdk import WorkspaceClient +from databricks_openai import DatabricksOpenAI + +workspace_client = WorkspaceClient() + +client = DatabricksOpenAI( + workspace_client=workspace_client, + use_ai_gateway=True, +) + +conversation = client.conversations.create() +response = client.responses.create( + model="databricks-claude-sonnet-4-5", + conversation=conversation.id, + input="Tell me about Databricks", +) +``` + --- ## Contribution Guide @@ -65,4 +92,3 @@ We welcome contributions! Please see our [contribution guidelines](https://githu This project is licensed under the [MIT License](LICENSE). Thank you for using Databricks OpenAI! 
- diff --git a/integrations/openai/src/databricks_openai/utils/clients.py b/integrations/openai/src/databricks_openai/utils/clients.py index 28138bdc..bc097a04 100644 --- a/integrations/openai/src/databricks_openai/utils/clients.py +++ b/integrations/openai/src/databricks_openai/utils/clients.py @@ -6,6 +6,7 @@ from openai import APIConnectionError, APIStatusError, AsyncOpenAI, OpenAI from openai.resources.chat import AsyncChat, Chat from openai.resources.chat.completions import AsyncCompletions, Completions +from openai.resources.conversations import AsyncConversations, Conversations from openai.resources.responses import AsyncResponses, Responses from typing_extensions import override @@ -137,6 +138,15 @@ def _resolve_base_url( return f"{host}/serving-endpoints" +def _resolve_conversations_base_url( + workspace_client: WorkspaceClient, + conversations_base_url: str | None, +) -> str: + if conversations_base_url is not None: + return conversations_base_url + return f"{workspace_client.config.host}/api/2.1/unity-catalog" + + def _get_authorized_http_client(workspace_client: WorkspaceClient) -> Client: databricks_token_auth = BearerAuth(workspace_client.config.authenticate) return Client(auth=databricks_token_auth) @@ -337,6 +347,8 @@ class DatabricksOpenAI(OpenAI): base_url: Optional base URL to override the default serving endpoints URL. When the URL points to a Databricks App (contains "databricksapps"), OAuth authentication is required. + conversations_base_url: Optional base URL to use for OpenAI Conversations API calls. + Defaults to ``{workspace_url}/api/2.1/unity-catalog``. use_ai_gateway_native_api: If True, auto-detect AI Gateway V2 and route requests through its native OpenAI-compatible API (``/openai/v1``). This allows use of provider-native features not available through the MLflow API. 
@property
def conversations(self) -> Conversations:
    """Return the Conversations resource bound to the conversations base URL.

    On first access a separate ``OpenAI`` client is built against
    ``self._conversations_base_url``, reusing this client's API key, HTTP
    client, timeout, retry count, default headers and default query. That
    client is stored on the instance and reused on later accesses.
    """
    if hasattr(self, "_databricks_conversations_client"):
        return self._databricks_conversations_client.conversations

    # Mirror this client's settings; only the base URL differs.
    client_options = {
        "base_url": self._conversations_base_url,
        "api_key": self.api_key,
        "http_client": self._client,
        "timeout": self.timeout,
        "max_retries": self.max_retries,
        "default_headers": self._custom_headers,
        "default_query": self._custom_query,
    }
    self._databricks_conversations_client = OpenAI(**client_options)
    return self._databricks_conversations_client.conversations
@property
def conversations(self) -> AsyncConversations:
    """Return the async Conversations resource bound to the conversations base URL.

    On first access a separate ``AsyncOpenAI`` client is built against
    ``self._conversations_base_url``, reusing this client's API key, HTTP
    client, timeout, retry count, default headers and default query. That
    client is stored on the instance and reused on later accesses.
    """
    if hasattr(self, "_databricks_conversations_client"):
        return self._databricks_conversations_client.conversations

    # Mirror this client's settings; only the base URL differs.
    client_options = {
        "base_url": self._conversations_base_url,
        "api_key": self.api_key,
        "http_client": self._client,
        "timeout": self.timeout,
        "max_retries": self.max_retries,
        "default_headers": self._custom_headers,
        "default_query": self._custom_query,
    }
    self._databricks_conversations_client = AsyncOpenAI(**client_options)
    return self._databricks_conversations_client.conversations
class TestConversationsBaseUrl:
    """Checks that Conversations API calls are routed to their own base URL."""

    def test_sync_conversations_use_default_unity_catalog_base_url(self, mock_workspace_client):
        """Sync client: with no override, conversations use the Unity Catalog path."""
        client = DatabricksOpenAI(workspace_client=mock_workspace_client, use_ai_gateway=True)

        assert isinstance(client.conversations, Conversations)

        main_url = str(client.base_url)
        conversations_url = str(client.conversations._client.base_url)
        assert "/ai-gateway/mlflow/v1/" in main_url
        assert "/api/2.1/unity-catalog/" in conversations_url

    def test_sync_conversations_use_override_base_url(self, mock_workspace_client):
        """Sync client: an explicit ``conversations_base_url`` is honoured."""
        override_url = "https://test.databricks.com/serving-endpoints"
        client = DatabricksOpenAI(
            workspace_client=mock_workspace_client,
            use_ai_gateway=True,
            conversations_base_url=override_url,
        )

        main_url = str(client.base_url)
        conversations_url = str(client.conversations._client.base_url)
        assert "/ai-gateway/mlflow/v1/" in main_url
        assert "/serving-endpoints/" in conversations_url
        # Two accesses must yield the very same resource object.
        assert client.conversations is client.conversations

    def test_async_conversations_use_default_unity_catalog_base_url(self, mock_workspace_client):
        """Async client: with no override, conversations use the Unity Catalog path."""
        client = AsyncDatabricksOpenAI(
            workspace_client=mock_workspace_client, use_ai_gateway=True
        )

        assert isinstance(client.conversations, AsyncConversations)

        main_url = str(client.base_url)
        conversations_url = str(client.conversations._client.base_url)
        assert "/ai-gateway/mlflow/v1/" in main_url
        assert "/api/2.1/unity-catalog/" in conversations_url

    def test_async_conversations_use_override_base_url(self, mock_workspace_client):
        """Async client: an explicit ``conversations_base_url`` is honoured."""
        override_url = "https://test.databricks.com/serving-endpoints"
        client = AsyncDatabricksOpenAI(
            workspace_client=mock_workspace_client,
            use_ai_gateway=True,
            conversations_base_url=override_url,
        )

        main_url = str(client.base_url)
        conversations_url = str(client.conversations._client.base_url)
        assert "/ai-gateway/mlflow/v1/" in main_url
        assert "/serving-endpoints/" in conversations_url
        # Two accesses must yield the very same resource object.
        assert client.conversations is client.conversations
prefix routing in DatabricksOpenAI and AsyncDatabricksOpenAI."""