From f253fba580a609a011f550f73cfbbd287d8c1ace Mon Sep 17 00:00:00 2001 From: MasterKenth Date: Tue, 26 Aug 2025 18:05:02 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Feat(backend/vector):=20add=20suppo?= =?UTF-8?q?rt=20for=20chromadb=20as=20a=20client-server=20connection?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/.env.example | 9 ++ backend/src/conftest.py | 4 +- .../vector/ChromaDBClientVectorService.py | 107 ++++++++++++++++++ ...rvice.py => ChromaDBLocalVectorService.py} | 2 +- backend/src/modules/vector/factory.py | 14 ++- ....py => test_ChromaDBLocalVectorService.py} | 0 6 files changed, 131 insertions(+), 5 deletions(-) create mode 100644 backend/src/modules/vector/ChromaDBClientVectorService.py rename backend/src/modules/vector/{ChromaDBVectorService.py => ChromaDBLocalVectorService.py} (98%) rename backend/src/modules/vector/{test_ChromaDBVectorService.py => test_ChromaDBLocalVectorService.py} (100%) diff --git a/backend/.env.example b/backend/.env.example index 424573e5..d8104178 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -13,3 +13,12 @@ SMTP_HOST="localhost" SMTP_PORT=587 SMTP_USER="my smtp user" SMTP_PASSWORD="my password" + +# Hostname of the ChromaDB server instance to use for file embeddings/vectorization. +# Will default to using a local ChromaDB instance using SQLite if not provided. +# More: https://docs.trychroma.com/docs/run-chroma/client-server +# CHROMADB_HOST=localhost + +# (Optional) Port of ChromaDB server to connect to (if CHROMADB_HOST is also set). +# Defaults to 8000 if not provided. +# CHROMADB_PORT=8000 diff --git a/backend/src/conftest.py b/backend/src/conftest.py index a0e2bec7..e9c0368b 100644 --- a/backend/src/conftest.py +++ b/backend/src/conftest.py @@ -8,7 +8,7 @@ from src.common.get_timestamp import get_timestamp from src.common.mock_services import MockSettingsService -from src.modules.vector.ChromaDBVectorService import ChromaDBVectorService +from src.modules.vector.ChromaDBLocalVectorService import ChromaDBLocalVectorService TEST_DB_PATH = './__test_chromadb' @@ -17,7 +17,7 @@ def vector_service(): path = TEST_DB_PATH + uuid.uuid4().hex shutil.rmtree(path, ignore_errors=True) - yield ChromaDBVectorService(settings_service=MockSettingsService(), db_path=path) + yield ChromaDBLocalVectorService(settings_service=MockSettingsService()) shutil.rmtree(path, ignore_errors=True) diff --git a/backend/src/modules/vector/ChromaDBClientVectorService.py b/backend/src/modules/vector/ChromaDBClientVectorService.py new file mode 100644 index 00000000..8ac01bb6 --- /dev/null +++ b/backend/src/modules/vector/ChromaDBClientVectorService.py @@ -0,0 +1,107 @@ +import chromadb +from chromadb import EmbeddingFunction, AsyncClientAPI +from chromadb.errors import InvalidCollectionException +from chromadb.utils import embedding_functions +from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction + +from src.modules.settings.protocols.ISettingsService import ISettingsService +from src.modules.settings.settings import SettingKey +from src.modules.vector.models.VectorDocument import VectorDocument +from src.modules.vector.models.VectorSpace import VectorSpace +from src.modules.vector.protocols.IVectorService import IVectorService + + +class ChromaDBClientVectorService(IVectorService): + def __init__(self, settings_service: ISettingsService, host: str, port: int): + self._chroma_client = None + self._settings_service = settings_service + self._host = host + self._port = port + + async def _get_client(self) -> AsyncClientAPI: + if self._chroma_client is None: + self._chroma_client = await chromadb.AsyncHttpClient(self._host, self._port, settings=chromadb.Settings( + anonymized_telemetry=False, + allow_reset=True + )) + return self._chroma_client + + async def create_vector_space(self, space: str, embedding_model: str): + embedding_function = await self._get_embedding_function(embedding_model) + client = await self._get_client() + try: + await client.get_collection(space, embedding_function) + except InvalidCollectionException: + await client.create_collection( + name=space, + embedding_function=embedding_function, + ) + + async def add_documents_to_vector_space(self, space: str, embedding_model: str, documents: list[VectorDocument]): + if len(documents) == 0: + return + + client = await self._get_client() + + collection = await client.get_collection( + space, + embedding_function=await self._get_embedding_function(embedding_model) + ) + + await collection.add( + documents=[doc.content for doc in documents], + ids=[doc.id for doc in documents], + metadatas=[doc.metadata for doc in documents] + ) + + async def delete_vector_space(self, space: str): + try: + client = await self._get_client() + await client.delete_collection(space) + except ValueError: + return + + async def get_vector_spaces(self) -> list[VectorSpace]: + client = await self._get_client() + return [VectorSpace(name=collection) for collection in await client.list_collections()] + + async def query_vector_space( + self, + space: str, + embedding_model: str, + query: str, + max_results: int + ) -> list[VectorDocument]: + try: + client = await self._get_client() + collection = await client.get_collection(space, embedding_function=await self._get_embedding_function( + embedding_model)) + except InvalidCollectionException: + return [] + + query_results = await collection.query( + query_texts=[query], + n_results=max_results + ) + + ids = query_results['ids'][0] + documents = query_results['documents'][0] + metadatas = query_results['metadatas'][0] + + return [VectorDocument( + id=ids[i], + content=documents[i], + metadata=dict(metadatas[i]) if metadatas[i] else None + ) for i in range(len(ids)) + ] + + async def _get_embedding_function(self, embedding_model_name: str) -> EmbeddingFunction: + openai_api_key = await self._settings_service.get_setting(SettingKey.OPENAI_API_KEY.key) + embedding_model_map = { + 'default': lambda: embedding_functions.DefaultEmbeddingFunction(), + 'text-embedding-3-small': lambda: OpenAIEmbeddingFunction( + api_key=openai_api_key, + model_name='text-embedding-3-small' + ) + } + return embedding_model_map[embedding_model_name or 'default']() diff --git a/backend/src/modules/vector/ChromaDBVectorService.py b/backend/src/modules/vector/ChromaDBLocalVectorService.py similarity index 98% rename from backend/src/modules/vector/ChromaDBVectorService.py rename to backend/src/modules/vector/ChromaDBLocalVectorService.py index 56bc4003..6c8e7b11 100644 --- a/backend/src/modules/vector/ChromaDBVectorService.py +++ b/backend/src/modules/vector/ChromaDBLocalVectorService.py @@ -11,7 +11,7 @@ from src.modules.vector.protocols.IVectorService import IVectorService -class ChromaDBVectorService(IVectorService): +class ChromaDBLocalVectorService(IVectorService): def __init__(self, settings_service: ISettingsService, db_path='./__chromadb'): self._settings_service = settings_service self._chroma_client = chromadb.PersistentClient( diff --git a/backend/src/modules/vector/factory.py b/backend/src/modules/vector/factory.py index 87a2101e..c22b0330 100644 --- a/backend/src/modules/vector/factory.py +++ b/backend/src/modules/vector/factory.py @@ -1,5 +1,8 @@ +import os + from src.modules.settings.protocols.ISettingsService import ISettingsService -from src.modules.vector.ChromaDBVectorService import ChromaDBVectorService +from src.modules.vector.ChromaDBClientVectorService import ChromaDBClientVectorService +from src.modules.vector.ChromaDBLocalVectorService import ChromaDBLocalVectorService from src.modules.vector.protocols.IVectorService import IVectorService @@ -8,4 +11,11 @@ def __init__(self, settings_service: ISettingsService): self._settings_service = settings_service def get(self) -> IVectorService: - return ChromaDBVectorService(settings_service=self._settings_service) + server_host = os.environ.get("CHROMADB_HOST") + server_port = os.environ.get("CHROMADB_PORT") + + if server_host is not None and len(server_host) > 0: + port = int(server_port) if server_port is not None and server_port.isdigit() else int(server_port) + return ChromaDBClientVectorService(settings_service=self._settings_service, host=server_host, port=port) + + return ChromaDBLocalVectorService(settings_service=self._settings_service) diff --git a/backend/src/modules/vector/test_ChromaDBVectorService.py b/backend/src/modules/vector/test_ChromaDBLocalVectorService.py similarity index 100% rename from backend/src/modules/vector/test_ChromaDBVectorService.py rename to backend/src/modules/vector/test_ChromaDBLocalVectorService.py