Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions backend/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,12 @@ SMTP_HOST="localhost"
SMTP_PORT=587
SMTP_USER="my smtp user"
SMTP_PASSWORD="my password"

# Hostname of the ChromaDB server instance to use for file embeddings/vectorization.
# If not provided, a local ChromaDB instance (backed by SQLite) is used instead.
# More: https://docs.trychroma.com/docs/run-chroma/client-server
# CHROMADB_HOST=localhost

# (Optional) Port of ChromaDB server to connect to (if CHROMADB_HOST is also set).
# Defaults to 8000 if not provided.
# CHROMADB_PORT=8000
4 changes: 2 additions & 2 deletions backend/src/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from src.common.get_timestamp import get_timestamp
from src.common.mock_services import MockSettingsService
from src.modules.vector.ChromaDBVectorService import ChromaDBVectorService
from src.modules.vector.ChromaDBLocalVectorService import ChromaDBLocalVectorService

# Path prefix for test ChromaDB directories; the vector_service fixture appends
# a random hex suffix to it for each test.
TEST_DB_PATH = './__test_chromadb'

Expand All @@ -17,7 +17,7 @@
def vector_service():
    """Yield a local ChromaDB vector service backed by a throwaway directory.

    A unique directory name is built from ``TEST_DB_PATH`` plus a random hex
    suffix. The directory is removed before the service is created and again
    once the test that requested the fixture has finished.
    """
    path = TEST_DB_PATH + uuid.uuid4().hex
    shutil.rmtree(path, ignore_errors=True)
    # db_path must be passed explicitly: without it the service falls back to
    # its default './__chromadb' location, so tests would share state and the
    # cleanup below would remove a directory that was never used.
    yield ChromaDBLocalVectorService(settings_service=MockSettingsService(), db_path=path)
    shutil.rmtree(path, ignore_errors=True)


Expand Down
107 changes: 107 additions & 0 deletions backend/src/modules/vector/ChromaDBClientVectorService.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import chromadb
from chromadb import EmbeddingFunction, AsyncClientAPI
from chromadb.errors import InvalidCollectionException
from chromadb.utils import embedding_functions
from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction

from src.modules.settings.protocols.ISettingsService import ISettingsService
from src.modules.settings.settings import SettingKey
from src.modules.vector.models.VectorDocument import VectorDocument
from src.modules.vector.models.VectorSpace import VectorSpace
from src.modules.vector.protocols.IVectorService import IVectorService


class ChromaDBClientVectorService(IVectorService):
    """Vector service that talks to a ChromaDB server through the async HTTP client.

    The client is created lazily on first use, so constructing the service
    does not open a connection.
    """

    def __init__(self, settings_service: ISettingsService, host: str, port: int):
        """Store the connection parameters; no client is created here.

        Args:
            settings_service: Used to read the OpenAI API key when building
                embedding functions.
            host: Host passed to ``chromadb.AsyncHttpClient``.
            port: Port passed to ``chromadb.AsyncHttpClient``.
        """
        # Populated by _get_client() on first use.
        self._chroma_client = None
        self._settings_service = settings_service
        self._host = host
        self._port = port

    async def _get_client(self) -> AsyncClientAPI:
        """Return the cached HTTP client, creating it on the first call."""
        if self._chroma_client is None:
            self._chroma_client = await chromadb.AsyncHttpClient(
                host=self._host,
                port=self._port,
                settings=chromadb.Settings(
                    anonymized_telemetry=False,
                    allow_reset=True,
                ),
            )
        return self._chroma_client

    async def create_vector_space(self, space: str, embedding_model: str):
        """Create the collection ``space`` unless it already exists.

        Args:
            space: Name of the collection.
            embedding_model: Key understood by ``_get_embedding_function``.
        """
        embedding_function = await self._get_embedding_function(embedding_model)
        client = await self._get_client()
        try:
            # Pass embedding_function by keyword, like every other call site:
            # the second positional parameter of get_collection is not
            # guaranteed to be the embedding function in all chromadb releases.
            await client.get_collection(space, embedding_function=embedding_function)
        except InvalidCollectionException:
            await client.create_collection(
                name=space,
                embedding_function=embedding_function,
            )

    async def add_documents_to_vector_space(self, space: str, embedding_model: str, documents: list[VectorDocument]):
        """Add ``documents`` to the existing collection ``space``.

        Does nothing (and does not contact the server) when ``documents`` is
        empty.

        Args:
            space: Name of the collection to add to.
            embedding_model: Key understood by ``_get_embedding_function``.
            documents: Documents whose ``content``, ``id`` and ``metadata``
                are sent to the collection.
        """
        if not documents:
            return

        client = await self._get_client()
        collection = await client.get_collection(
            space,
            embedding_function=await self._get_embedding_function(embedding_model),
        )

        await collection.add(
            documents=[doc.content for doc in documents],
            ids=[doc.id for doc in documents],
            metadatas=[doc.metadata for doc in documents],
        )

    async def delete_vector_space(self, space: str):
        """Delete the collection ``space``; a ``ValueError`` is silently ignored.

        NOTE(review): ``ValueError`` is presumably what signals a missing
        collection here - confirm which exception the HTTP client actually
        raises. Client creation also happens inside the ``try``, so a
        ``ValueError`` raised while connecting is swallowed as well.
        """
        try:
            client = await self._get_client()
            await client.delete_collection(space)
        except ValueError:
            return

    async def get_vector_spaces(self) -> list[VectorSpace]:
        """Return one ``VectorSpace`` per entry reported by ``list_collections``."""
        client = await self._get_client()
        # Each entry is used directly as the VectorSpace name.
        # NOTE(review): assumes list_collections yields collection names - confirm.
        return [VectorSpace(name=collection) for collection in await client.list_collections()]

    async def query_vector_space(
            self,
            space: str,
            embedding_model: str,
            query: str,
            max_results: int
    ) -> list[VectorDocument]:
        """Query ``space`` with a single text and return the matching documents.

        Args:
            space: Name of the collection to search.
            embedding_model: Key understood by ``_get_embedding_function``.
            query: Text to search for.
            max_results: Passed to the collection query as ``n_results``.

        Returns:
            The matches as ``VectorDocument`` objects, or an empty list when
            looking up the collection raises ``InvalidCollectionException``.
        """
        try:
            client = await self._get_client()
            collection = await client.get_collection(
                space,
                embedding_function=await self._get_embedding_function(embedding_model),
            )
        except InvalidCollectionException:
            return []

        query_results = await collection.query(
            query_texts=[query],
            n_results=max_results,
        )

        # One query text was sent, so only the first result row is read.
        ids = query_results['ids'][0]
        documents = query_results['documents'][0]
        metadatas = query_results['metadatas'][0]

        return [
            VectorDocument(
                id=doc_id,
                content=content,
                metadata=dict(metadata) if metadata else None,
            )
            for doc_id, content, metadata in zip(ids, documents, metadatas)
        ]

    async def _get_embedding_function(self, embedding_model_name: str) -> EmbeddingFunction:
        """Build the embedding function registered under ``embedding_model_name``.

        A falsy name selects ``'default'``.

        Raises:
            KeyError: If the name is not one of the supported models.
        """
        openai_api_key = await self._settings_service.get_setting(SettingKey.OPENAI_API_KEY.key)
        # Values are factories so that only the selected function is constructed.
        embedding_model_map = {
            'default': lambda: embedding_functions.DefaultEmbeddingFunction(),
            'text-embedding-3-small': lambda: OpenAIEmbeddingFunction(
                api_key=openai_api_key,
                model_name='text-embedding-3-small'
            )
        }
        return embedding_model_map[embedding_model_name or 'default']()
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from src.modules.vector.protocols.IVectorService import IVectorService


class ChromaDBVectorService(IVectorService):
class ChromaDBLocalVectorService(IVectorService):
def __init__(self, settings_service: ISettingsService, db_path='./__chromadb'):
self._settings_service = settings_service
self._chroma_client = chromadb.PersistentClient(
Expand Down
14 changes: 12 additions & 2 deletions backend/src/modules/vector/factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os

from src.modules.settings.protocols.ISettingsService import ISettingsService
from src.modules.vector.ChromaDBVectorService import ChromaDBVectorService
from src.modules.vector.ChromaDBClientVectorService import ChromaDBClientVectorService
from src.modules.vector.ChromaDBLocalVectorService import ChromaDBLocalVectorService
from src.modules.vector.protocols.IVectorService import IVectorService


Expand All @@ -8,4 +11,11 @@ def __init__(self, settings_service: ISettingsService):
self._settings_service = settings_service

def get(self) -> IVectorService:
    """Return the vector service selected by the environment.

    When ``CHROMADB_HOST`` is set to a non-empty value, a client for that
    ChromaDB server is returned; ``CHROMADB_PORT`` selects the port and
    defaults to 8000 when unset or blank. Otherwise the local ChromaDB
    service is returned.

    Raises:
        ValueError: If ``CHROMADB_PORT`` is set but is not an integer.
    """
    default_port = 8000
    server_host = os.environ.get("CHROMADB_HOST")
    server_port = os.environ.get("CHROMADB_PORT")

    if not server_host:
        return ChromaDBLocalVectorService(settings_service=self._settings_service)

    if server_port is None or not server_port.strip():
        # The port is optional: fall back to the default instead of calling
        # int(None), which raises TypeError.
        port = default_port
    else:
        try:
            port = int(server_port)
        except ValueError as err:
            raise ValueError(
                f"CHROMADB_PORT must be an integer, got {server_port!r}"
            ) from err

    return ChromaDBClientVectorService(settings_service=self._settings_service, host=server_host, port=port)