From 8fd8f0d2fe8c844b5e3da46b9448c1838c9bbd19 Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Tue, 7 Apr 2026 23:30:56 +0000 Subject: [PATCH 01/19] Move config to base dir --- src/omop_emb/__init__.py | 2 +- src/omop_emb/backends/base.py | 2 +- src/omop_emb/backends/embedding_utils.py | 2 +- src/omop_emb/backends/factory.py | 2 +- src/omop_emb/backends/faiss/faiss_backend.py | 3 +-- src/omop_emb/backends/faiss/faiss_sql.py | 2 +- src/omop_emb/backends/faiss/index_manager.py | 2 +- src/omop_emb/backends/faiss/storage_manager.py | 2 +- src/omop_emb/backends/pgvector/pgvector_backend.py | 2 +- src/omop_emb/backends/pgvector/pgvector_sql.py | 2 +- src/omop_emb/backends/registry.py | 2 +- src/omop_emb/cli.py | 2 +- src/omop_emb/{backends => }/config.py | 0 src/omop_emb/interface.py | 2 +- tests/conftest.py | 1 - tests/shared_backend_tests.py | 2 +- tests/test_config.py | 6 +++--- tests/test_interface.py | 2 +- 18 files changed, 18 insertions(+), 20 deletions(-) rename src/omop_emb/{backends => }/config.py (100%) diff --git a/src/omop_emb/__init__.py b/src/omop_emb/__init__.py index 8f791a1..6ce8b62 100644 --- a/src/omop_emb/__init__.py +++ b/src/omop_emb/__init__.py @@ -1,6 +1,6 @@ from .interface import EmbeddingInterface from .backends.base import EmbeddingConceptFilter -from .backends.config import BackendType, IndexType, MetricType +from .config import BackendType, IndexType, MetricType from .backends.factory import get_embedding_backend __all__ = [ diff --git a/src/omop_emb/backends/base.py b/src/omop_emb/backends/base.py index 1765a63..a1a811a 100644 --- a/src/omop_emb/backends/base.py +++ b/src/omop_emb/backends/base.py @@ -12,7 +12,7 @@ from omop_alchemy.cdm.model.vocabulary import Concept from .registry import ModelRegistry, ensure_model_registry_schema -from .config import BackendType, SUPPORTED_INDICES_AND_METRICS_PER_BACKEND, IndexType, MetricType +from ..config import BackendType, IndexType, MetricType from .errors import ModelRegistrationConflictError 
from .embedding_utils import ( EmbeddingModelRecord, diff --git a/src/omop_emb/backends/embedding_utils.py b/src/omop_emb/backends/embedding_utils.py index 5f66576..8cdb503 100644 --- a/src/omop_emb/backends/embedding_utils.py +++ b/src/omop_emb/backends/embedding_utils.py @@ -4,7 +4,7 @@ from sqlalchemy import Select from omop_alchemy.cdm.model.vocabulary import Concept -from .config import BackendType, SUPPORTED_INDICES_AND_METRICS_PER_BACKEND, IndexType, MetricType +from ..config import BackendType, IndexType, MetricType @dataclass(frozen=True) diff --git a/src/omop_emb/backends/factory.py b/src/omop_emb/backends/factory.py index 37169bc..830c0d3 100644 --- a/src/omop_emb/backends/factory.py +++ b/src/omop_emb/backends/factory.py @@ -4,7 +4,7 @@ from typing import Optional from .base import EmbeddingBackend -from .config import ( +from ..config import ( BackendType, ENV_OMOP_EMB_BACKEND, ) diff --git a/src/omop_emb/backends/faiss/faiss_backend.py b/src/omop_emb/backends/faiss/faiss_backend.py index decf1dc..3073011 100644 --- a/src/omop_emb/backends/faiss/faiss_backend.py +++ b/src/omop_emb/backends/faiss/faiss_backend.py @@ -6,7 +6,6 @@ import numpy as np from numpy import ndarray -from omop_emb.backends.config import BackendType from sqlalchemy import Engine, insert, select from sqlalchemy.orm import Session import logging @@ -23,7 +22,7 @@ add_concept_ids_to_faiss_registry, q_concept_ids_with_embeddings ) -from ..config import IndexType, MetricType, ENV_OMOP_EMB_FAISS_INDEX_DIR +from omop_emb.config import BackendType, IndexType, MetricType, ENV_OMOP_EMB_FAISS_INDEX_DIR from .storage_manager import EmbeddingStorageManager from ..base import EmbeddingBackend, require_registered_model from ..embedding_utils import ( diff --git a/src/omop_emb/backends/faiss/faiss_sql.py b/src/omop_emb/backends/faiss/faiss_sql.py index 1d0d6c6..b32ce71 100644 --- a/src/omop_emb/backends/faiss/faiss_sql.py +++ b/src/omop_emb/backends/faiss/faiss_sql.py @@ -10,7 +10,7 @@ from 
omop_alchemy.cdm.model.vocabulary import Concept from ..base import ConceptIDEmbeddingBase -from ..config import BackendType +from omop_emb.config import BackendType from ..registry import ModelRegistry from ..embedding_utils import EmbeddingConceptFilter diff --git a/src/omop_emb/backends/faiss/index_manager.py b/src/omop_emb/backends/faiss/index_manager.py index 52e8b9c..bd33e50 100644 --- a/src/omop_emb/backends/faiss/index_manager.py +++ b/src/omop_emb/backends/faiss/index_manager.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import Optional, Generator, Tuple -from omop_emb.backends.config import IndexType, MetricType +from omop_emb.config import IndexType, MetricType import logging logger = logging.getLogger(__name__) diff --git a/src/omop_emb/backends/faiss/storage_manager.py b/src/omop_emb/backends/faiss/storage_manager.py index 44bf423..fca9331 100644 --- a/src/omop_emb/backends/faiss/storage_manager.py +++ b/src/omop_emb/backends/faiss/storage_manager.py @@ -11,7 +11,7 @@ import logging logger = logging.getLogger(__name__) -from omop_emb.backends.config import BackendType, IndexType, MetricType +from omop_emb.config import BackendType, IndexType, MetricType from .index_manager import FlatIndexManager, BaseIndexManager class EmbeddingStorageManager: diff --git a/src/omop_emb/backends/pgvector/pgvector_backend.py b/src/omop_emb/backends/pgvector/pgvector_backend.py index 92050ad..85fea8f 100644 --- a/src/omop_emb/backends/pgvector/pgvector_backend.py +++ b/src/omop_emb/backends/pgvector/pgvector_backend.py @@ -15,7 +15,7 @@ add_embeddings_to_registered_table, ) from ..registry import ModelRegistry -from ..config import BackendType, MetricType +from omop_emb.config import BackendType, MetricType from ..base import EmbeddingBackend, require_registered_model from ..embedding_utils import ( EmbeddingBackendCapabilities, diff --git a/src/omop_emb/backends/pgvector/pgvector_sql.py b/src/omop_emb/backends/pgvector/pgvector_sql.py index 855c7b1..3d19beb 
100644 --- a/src/omop_emb/backends/pgvector/pgvector_sql.py +++ b/src/omop_emb/backends/pgvector/pgvector_sql.py @@ -11,7 +11,7 @@ import logging from numpy import ndarray -from ..config import BackendType, IndexType, MetricType +from omop_emb.config import BackendType, IndexType, MetricType from ..registry import ModelRegistry from ..base import ConceptIDEmbeddingBase from ..embedding_utils import EmbeddingConceptFilter, get_similarity_from_distance diff --git a/src/omop_emb/backends/registry.py b/src/omop_emb/backends/registry.py index 868b5cb..9b16e23 100644 --- a/src/omop_emb/backends/registry.py +++ b/src/omop_emb/backends/registry.py @@ -5,7 +5,7 @@ from orm_loader.helpers import Base -from .config import ( +from ..config import ( is_index_type_supported_for_backend, get_supported_index_types_for_backend, IndexType, diff --git a/src/omop_emb/cli.py b/src/omop_emb/cli.py index 28957dd..5eb1170 100644 --- a/src/omop_emb/cli.py +++ b/src/omop_emb/cli.py @@ -16,7 +16,7 @@ EmbeddingConceptFilter, ) from omop_emb.interface import EmbeddingInterface -from omop_emb.backends.config import BackendType, IndexType +from omop_emb.config import IndexType app = typer.Typer() logger = get_logger(__name__) diff --git a/src/omop_emb/backends/config.py b/src/omop_emb/config.py similarity index 100% rename from src/omop_emb/backends/config.py rename to src/omop_emb/config.py diff --git a/src/omop_emb/interface.py b/src/omop_emb/interface.py index 267c645..2e5ae42 100644 --- a/src/omop_emb/interface.py +++ b/src/omop_emb/interface.py @@ -16,7 +16,7 @@ EmbeddingModelRecord, get_embedding_backend, ) -from .backends.config import IndexType, MetricType +from .config import IndexType, MetricType @dataclass diff --git a/tests/conftest.py b/tests/conftest.py index 44c9480..8ec92d8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,7 +25,6 @@ from omop_emb.backends.pgvector import PGVectorEmbeddingBackend from omop_emb.backends.registry import ensure_model_registry_schema, 
ModelRegistry from omop_emb.interface import EmbeddingInterface -from omop_emb.backends.config import IndexType TEST_DB_NAME = os.getenv("TEST_DATABASE_NAME", "test_omop_emb") diff --git a/tests/shared_backend_tests.py b/tests/shared_backend_tests.py index 2b6c8f8..ddb00a9 100644 --- a/tests/shared_backend_tests.py +++ b/tests/shared_backend_tests.py @@ -3,7 +3,7 @@ import pytest import numpy as np -from omop_emb.backends.config import IndexType, MetricType +from omop_emb.config import IndexType, MetricType from omop_emb.backends.embedding_utils import EmbeddingConceptFilter from omop_emb.backends.errors import ModelRegistrationConflictError from .conftest import CONCEPTS, MODEL_NAME, EMBEDDING_DIM, TEST_CONCEPT_EMB diff --git a/tests/test_config.py b/tests/test_config.py index 824a09a..8e247ea 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,7 +3,7 @@ import pytest from omop_emb.backends.factory import normalize_backend_name, get_embedding_backend -from omop_emb.backends.config import BackendType, IndexType, MetricType +from omop_emb.config import BackendType, IndexType, MetricType @pytest.mark.unit @@ -22,13 +22,13 @@ def test_get_faiss_backend(self, tmp_path): def test_faiss_supports_flat_index(self): """Test FAISS supports FLAT index.""" - from omop_emb.backends.config import is_index_type_supported_for_backend + from omop_emb.config import is_index_type_supported_for_backend assert is_index_type_supported_for_backend(BackendType.FAISS, IndexType.FLAT) def test_faiss_supports_cosine_metric(self): """Test FAISS supports COSINE metric.""" - from omop_emb.backends.config import is_supported_index_metric_combination_for_backend + from omop_emb.config import is_supported_index_metric_combination_for_backend assert is_supported_index_metric_combination_for_backend( BackendType.FAISS, diff --git a/tests/test_interface.py b/tests/test_interface.py index a49ff0f..c38916a 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -6,7 +6,7 @@ 
from sqlalchemy.orm import Session from omop_emb.interface import EmbeddingInterface -from omop_emb.backends.config import IndexType, MetricType +from omop_emb.config import IndexType, MetricType from omop_emb.backends.base import NearestConceptMatch from .conftest import CONCEPTS, MODEL_NAME, EMBEDDING_DIM From 4ac5e08c8351e5605c9a6ae44062b5cba65b0f7d Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Tue, 7 Apr 2026 23:35:50 +0000 Subject: [PATCH 02/19] Move embedding utils in utils directory --- src/omop_emb/backends/__init__.py | 10 ---------- src/omop_emb/backends/base.py | 2 +- src/omop_emb/backends/faiss/faiss_backend.py | 2 +- src/omop_emb/backends/faiss/faiss_sql.py | 2 +- src/omop_emb/backends/pgvector/pgvector_backend.py | 2 +- src/omop_emb/backends/pgvector/pgvector_sql.py | 2 +- src/omop_emb/interface.py | 4 +++- src/omop_emb/utils/__init__.py | 0 src/omop_emb/{backends => utils}/embedding_utils.py | 0 tests/shared_backend_tests.py | 2 +- 10 files changed, 9 insertions(+), 17 deletions(-) create mode 100644 src/omop_emb/utils/__init__.py rename src/omop_emb/{backends => utils}/embedding_utils.py (100%) diff --git a/src/omop_emb/backends/__init__.py b/src/omop_emb/backends/__init__.py index ede168c..31bdc40 100644 --- a/src/omop_emb/backends/__init__.py +++ b/src/omop_emb/backends/__init__.py @@ -1,12 +1,6 @@ from .base import ( EmbeddingBackend, ) -from .embedding_utils import ( - EmbeddingModelRecord, - EmbeddingBackendCapabilities, - EmbeddingConceptFilter, - NearestConceptMatch, -) from .errors import ( EmbeddingBackendConfigurationError, EmbeddingBackendDependencyError, @@ -20,10 +14,6 @@ __all__ = [ "EmbeddingBackend", - "EmbeddingBackendCapabilities", - "EmbeddingConceptFilter", - "EmbeddingModelRecord", - "NearestConceptMatch", "EmbeddingBackendConfigurationError", "EmbeddingBackendDependencyError", "EmbeddingBackendError", diff --git a/src/omop_emb/backends/base.py b/src/omop_emb/backends/base.py index a1a811a..d581d0c 100644 --- 
a/src/omop_emb/backends/base.py +++ b/src/omop_emb/backends/base.py @@ -14,7 +14,7 @@ from .registry import ModelRegistry, ensure_model_registry_schema from ..config import BackendType, IndexType, MetricType from .errors import ModelRegistrationConflictError -from .embedding_utils import ( +from omop_emb.utils.embedding_utils import ( EmbeddingModelRecord, EmbeddingConceptFilter, NearestConceptMatch, diff --git a/src/omop_emb/backends/faiss/faiss_backend.py b/src/omop_emb/backends/faiss/faiss_backend.py index 3073011..e2e50e4 100644 --- a/src/omop_emb/backends/faiss/faiss_backend.py +++ b/src/omop_emb/backends/faiss/faiss_backend.py @@ -25,7 +25,7 @@ from omop_emb.config import BackendType, IndexType, MetricType, ENV_OMOP_EMB_FAISS_INDEX_DIR from .storage_manager import EmbeddingStorageManager from ..base import EmbeddingBackend, require_registered_model -from ..embedding_utils import ( +from omop_emb.utils.embedding_utils import ( EmbeddingBackendCapabilities, EmbeddingConceptFilter, EmbeddingModelRecord, diff --git a/src/omop_emb/backends/faiss/faiss_sql.py b/src/omop_emb/backends/faiss/faiss_sql.py index b32ce71..b24467b 100644 --- a/src/omop_emb/backends/faiss/faiss_sql.py +++ b/src/omop_emb/backends/faiss/faiss_sql.py @@ -12,7 +12,7 @@ from ..base import ConceptIDEmbeddingBase from omop_emb.config import BackendType from ..registry import ModelRegistry -from ..embedding_utils import EmbeddingConceptFilter +from omop_emb.utils.embedding_utils import EmbeddingConceptFilter logger = logging.getLogger(__name__) diff --git a/src/omop_emb/backends/pgvector/pgvector_backend.py b/src/omop_emb/backends/pgvector/pgvector_backend.py index 85fea8f..e4c88da 100644 --- a/src/omop_emb/backends/pgvector/pgvector_backend.py +++ b/src/omop_emb/backends/pgvector/pgvector_backend.py @@ -17,7 +17,7 @@ from ..registry import ModelRegistry from omop_emb.config import BackendType, MetricType from ..base import EmbeddingBackend, require_registered_model -from ..embedding_utils import 
( +from omop_emb.utils.embedding_utils import ( EmbeddingBackendCapabilities, EmbeddingConceptFilter, EmbeddingModelRecord, diff --git a/src/omop_emb/backends/pgvector/pgvector_sql.py b/src/omop_emb/backends/pgvector/pgvector_sql.py index 3d19beb..f559d23 100644 --- a/src/omop_emb/backends/pgvector/pgvector_sql.py +++ b/src/omop_emb/backends/pgvector/pgvector_sql.py @@ -14,7 +14,7 @@ from omop_emb.config import BackendType, IndexType, MetricType from ..registry import ModelRegistry from ..base import ConceptIDEmbeddingBase -from ..embedding_utils import EmbeddingConceptFilter, get_similarity_from_distance +from omop_emb.utils.embedding_utils import EmbeddingConceptFilter, get_similarity_from_distance logger = logging.getLogger(__name__) diff --git a/src/omop_emb/interface.py b/src/omop_emb/interface.py index 2e5ae42..8e32fae 100644 --- a/src/omop_emb/interface.py +++ b/src/omop_emb/interface.py @@ -12,9 +12,11 @@ from .backends import ( EmbeddingBackend, + get_embedding_backend, +) +from omop_emb.utils.embedding_utils import ( EmbeddingConceptFilter, EmbeddingModelRecord, - get_embedding_backend, ) from .config import IndexType, MetricType diff --git a/src/omop_emb/utils/__init__.py b/src/omop_emb/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/omop_emb/backends/embedding_utils.py b/src/omop_emb/utils/embedding_utils.py similarity index 100% rename from src/omop_emb/backends/embedding_utils.py rename to src/omop_emb/utils/embedding_utils.py diff --git a/tests/shared_backend_tests.py b/tests/shared_backend_tests.py index ddb00a9..f346da4 100644 --- a/tests/shared_backend_tests.py +++ b/tests/shared_backend_tests.py @@ -4,7 +4,7 @@ import numpy as np from omop_emb.config import IndexType, MetricType -from omop_emb.backends.embedding_utils import EmbeddingConceptFilter +from omop_emb.utils.embedding_utils import EmbeddingConceptFilter from omop_emb.backends.errors import ModelRegistrationConflictError from .conftest import CONCEPTS, 
MODEL_NAME, EMBEDDING_DIM, TEST_CONCEPT_EMB From fb3067fb97a3ba4b7f2234cf4c34bf194cf6cbc6 Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Tue, 7 Apr 2026 23:37:43 +0000 Subject: [PATCH 03/19] Centralised errors in utils --- src/omop_emb/backends/__init__.py | 10 ---------- src/omop_emb/backends/base.py | 2 +- src/omop_emb/backends/factory.py | 2 +- src/omop_emb/backends/faiss/faiss_backend.py | 1 - src/omop_emb/{backends => utils}/errors.py | 0 tests/shared_backend_tests.py | 2 +- 6 files changed, 3 insertions(+), 14 deletions(-) rename src/omop_emb/{backends => utils}/errors.py (100%) diff --git a/src/omop_emb/backends/__init__.py b/src/omop_emb/backends/__init__.py index 31bdc40..4096d1c 100644 --- a/src/omop_emb/backends/__init__.py +++ b/src/omop_emb/backends/__init__.py @@ -1,12 +1,6 @@ from .base import ( EmbeddingBackend, ) -from .errors import ( - EmbeddingBackendConfigurationError, - EmbeddingBackendDependencyError, - EmbeddingBackendError, - UnknownEmbeddingBackendError, -) from .factory import ( get_embedding_backend, normalize_backend_name, @@ -14,10 +8,6 @@ __all__ = [ "EmbeddingBackend", - "EmbeddingBackendConfigurationError", - "EmbeddingBackendDependencyError", - "EmbeddingBackendError", - "UnknownEmbeddingBackendError", "get_embedding_backend", "normalize_backend_name", "FaissEmbeddingBackend", diff --git a/src/omop_emb/backends/base.py b/src/omop_emb/backends/base.py index d581d0c..955c2cb 100644 --- a/src/omop_emb/backends/base.py +++ b/src/omop_emb/backends/base.py @@ -13,7 +13,7 @@ from .registry import ModelRegistry, ensure_model_registry_schema from ..config import BackendType, IndexType, MetricType -from .errors import ModelRegistrationConflictError +from omop_emb.utils.errors import ModelRegistrationConflictError from omop_emb.utils.embedding_utils import ( EmbeddingModelRecord, EmbeddingConceptFilter, diff --git a/src/omop_emb/backends/factory.py b/src/omop_emb/backends/factory.py index 830c0d3..4b921f9 100644 --- 
a/src/omop_emb/backends/factory.py +++ b/src/omop_emb/backends/factory.py @@ -8,7 +8,7 @@ BackendType, ENV_OMOP_EMB_BACKEND, ) -from .errors import ( +from omop_emb.utils.errors import ( EmbeddingBackendConfigurationError, EmbeddingBackendDependencyError, UnknownEmbeddingBackendError, diff --git a/src/omop_emb/backends/faiss/faiss_backend.py b/src/omop_emb/backends/faiss/faiss_backend.py index e2e50e4..045e789 100644 --- a/src/omop_emb/backends/faiss/faiss_backend.py +++ b/src/omop_emb/backends/faiss/faiss_backend.py @@ -15,7 +15,6 @@ from omop_alchemy.cdm.model.vocabulary import Concept from omop_emb.backends.registry import ModelRegistry -from ..errors import EmbeddingBackendConfigurationError from .faiss_sql import ( FAISSConceptIDEmbeddingRegistry, create_faiss_embedding_registry_table, diff --git a/src/omop_emb/backends/errors.py b/src/omop_emb/utils/errors.py similarity index 100% rename from src/omop_emb/backends/errors.py rename to src/omop_emb/utils/errors.py diff --git a/tests/shared_backend_tests.py b/tests/shared_backend_tests.py index f346da4..c0fb70d 100644 --- a/tests/shared_backend_tests.py +++ b/tests/shared_backend_tests.py @@ -5,7 +5,7 @@ from omop_emb.config import IndexType, MetricType from omop_emb.utils.embedding_utils import EmbeddingConceptFilter -from omop_emb.backends.errors import ModelRegistrationConflictError +from omop_emb.utils.errors import ModelRegistrationConflictError from .conftest import CONCEPTS, MODEL_NAME, EMBEDDING_DIM, TEST_CONCEPT_EMB From 57085599da2f5921e734b28dbcc9a72454d63ea3 Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Wed, 8 Apr 2026 00:23:19 +0000 Subject: [PATCH 04/19] Have the embedding_table_cache per file to circumvent longliving declarative base --- src/omop_emb/backends/base.py | 62 ++++++++++--------- src/omop_emb/backends/faiss/faiss_backend.py | 4 +- src/omop_emb/backends/faiss/faiss_sql.py | 26 ++++---- .../backends/pgvector/pgvector_backend.py | 9 +-- .../backends/pgvector/pgvector_sql.py | 44 
++++++++----- src/omop_emb/cli.py | 7 +-- src/omop_emb/interface.py | 6 +- tests/test_dummy.py | 3 - 8 files changed, 81 insertions(+), 80 deletions(-) delete mode 100644 tests/test_dummy.py diff --git a/src/omop_emb/backends/base.py b/src/omop_emb/backends/base.py index 955c2cb..4a95de5 100644 --- a/src/omop_emb/backends/base.py +++ b/src/omop_emb/backends/base.py @@ -71,12 +71,12 @@ class EmbeddingBackend(ABC, Generic[T]): """ def __init__(self): super().__init__() - self._model_cache: dict[str, Type[T]] = {} + self._embedding_table_cache: dict[str, Type[T]] = {} @property - def model_cache(self) -> Dict[str, Type[T]]: + def embedding_table_cache(self) -> Dict[str, Type[T]]: """In-memory cache of dynamically generated ORM classes for embedding tables.""" - return self._model_cache + return self._embedding_table_cache @property @abstractmethod @@ -110,10 +110,8 @@ def initialise_store(self, engine: Engine) -> None: session.commit() for model_entry in existing_models: - if model_entry.model_name not in self.model_cache: - # Create the class and cache it - dynamic_table = self._create_storage_table(engine=engine, entry=model_entry) - self.model_cache[model_entry.model_name] = dynamic_table + dynamic_table = self._create_storage_table(engine=engine, entry=model_entry) + self.embedding_table_cache[model_entry.model_name] = dynamic_table def list_registered_models(self, session: Session) -> Sequence[EmbeddingModelRecord]: """Return all embedding models known to this backend.""" @@ -152,22 +150,20 @@ def get_concepts_without_embedding( limit: Optional[int] = None, ) -> Mapping[int, str]: """Return concept IDs and names for concepts that do not have embeddings.""" - query = self.get_concepts_without_embedding_query( - session=session, + query = self.q_get_concepts_without_embedding( model_name=model_name, concept_filter=concept_filter, limit=limit ) return {row.concept_id: row.concept_name for row in session.execute(query)} - def get_concepts_without_embedding_query( + 
def q_get_concepts_without_embedding( self, - session: Session, model_name: str, concept_filter: Optional[EmbeddingConceptFilter] = None, limit: Optional[int] = None, ) -> Select: - embedding_table = self._get_embedding_table(session=session, model_name=model_name) + embedding_table = self.get_embedding_table(model_name=model_name) subq = select(1).where(embedding_table.concept_id == Concept.concept_id) query = ( @@ -186,7 +182,7 @@ def get_concepts_without_embedding_count( model_name: str, concept_filter: Optional[EmbeddingConceptFilter] = None, ) -> int: - embedding_table = self._get_embedding_table(session=session, model_name=model_name) + embedding_table = self.get_embedding_table(model_name=model_name) subq = select(1).where(embedding_table.concept_id == Concept.concept_id) query = ( @@ -269,7 +265,7 @@ def register_model( session.commit() dynamic_table = self._create_storage_table(engine, new_entry) - self.model_cache[model_name] = dynamic_table + self.embedding_table_cache[model_name] = dynamic_table return self._record_from_registry_row(new_entry) @abstractmethod @@ -378,30 +374,38 @@ def refresh_model_index(self, session: Session, model_name: str) -> None: return None def has_any_embeddings(self, session: Session, model_name: str) -> bool: - embedding_table = self._get_embedding_table( - session=session, + embedding_table = self.get_embedding_table( model_name=model_name, ) return session.query(embedding_table.concept_id).limit(1).first() is not None - def _get_embedding_table( + def get_embedding_table( self, - session: Session, model_name: str, ) -> Type[T]: - embedding_table = self.model_cache.get(model_name) + """Return dynamically created ORM class for the embedding table associated with the given model name. + This relies on the `initialise_store` method having been called to populate the `embedding_table_cache`. 
+ If the requested model name is not found in the cache, this indicates a logic error in the calling code (e.g., not initializing the store, or requesting a table for a model that hasn't been registered) rather than a user input error, so a ValueError is raised. + + Parameters + ---------- + model_name : str + The name of the embedding model whose associated embedding table class is to be retrieved. + Returns + ------- + Type[T] + The dynamically generated SQLAlchemy ORM class corresponding to the embedding table for the specified model. + + Raises + ------ + ValueError + If the model name is not found in the embedding table cache, indicating that the store was not properly initialized or the model has not been registered. + """ + embedding_table = self.embedding_table_cache.get(model_name) if embedding_table is not None: return embedding_table - - bind = session.get_bind() - if bind is None: - raise RuntimeError("Session is not bound to an engine.") - assert isinstance(bind, Engine), f"Expected session bind to be an Engine. Got {type(bind)}" - self.initialise_store(bind) # Ensure tables are created and cache is populated - embedding_table = self.model_cache.get(model_name) - if embedding_table is None: - raise ValueError(f"Embedding model '{model_name}' not found in cache.") - return embedding_table + else: + raise ValueError(f"Model '{model_name}' not found in cache. 
Ensure the model is registered and the store is initialized.") @staticmethod def _coerce_registry_metadata(value: object) -> Mapping[str, object]: diff --git a/src/omop_emb/backends/faiss/faiss_backend.py b/src/omop_emb/backends/faiss/faiss_backend.py index 045e789..3f05af2 100644 --- a/src/omop_emb/backends/faiss/faiss_backend.py +++ b/src/omop_emb/backends/faiss/faiss_backend.py @@ -248,7 +248,7 @@ def upsert_embeddings( add_concept_ids_to_faiss_registry( concept_ids=concept_id_tuple, session=session, - registered_table=self._get_embedding_table(session, model_name), + registered_table=self.get_embedding_table(model_name), ) except Exception as e: session.rollback() @@ -292,7 +292,7 @@ def get_nearest_concepts( self.validate_embeddings(embeddings=query_embeddings, dimensions=model_record.dimensions) q_permitted_concept_ids = q_concept_ids_with_embeddings( - embedding_table=self._get_embedding_table(session, model_name), + embedding_table=self.get_embedding_table(model_name), concept_filter=concept_filter, limit=None ) diff --git a/src/omop_emb/backends/faiss/faiss_sql.py b/src/omop_emb/backends/faiss/faiss_sql.py index b24467b..b99e050 100644 --- a/src/omop_emb/backends/faiss/faiss_sql.py +++ b/src/omop_emb/backends/faiss/faiss_sql.py @@ -17,6 +17,9 @@ logger = logging.getLogger(__name__) +_FAISS_REGISTRY_TABLE_CACHE: dict[str, type["FAISSConceptIDEmbeddingRegistry"]] = {} + + class FAISSConceptIDEmbeddingRegistry(ConceptIDEmbeddingBase, Base): """Registry table to track which concept_ids are present in the FAISS/H5 storage for a specific model.""" __abstract__ = True @@ -31,16 +34,23 @@ def create_faiss_embedding_registry_table( """ tablename = model_registry_entry.storage_identifier + cached_table = _FAISS_REGISTRY_TABLE_CACHE.get(tablename) + if cached_table is not None: + Base.metadata.create_all(engine, tables=[cached_table.__table__]) # type: ignore[arg-type] + return cached_table + mapping_table = type( - tablename, + 
f"FAISSConceptIDEmbeddingRegistry_{tablename}", (FAISSConceptIDEmbeddingRegistry, ), { "__tablename__": tablename, "__table_args__": {"extend_existing": True}, + "__module__": __name__, }, ) Base.metadata.create_all(engine, tables=[mapping_table.__table__]) # type: ignore[arg-type] + _FAISS_REGISTRY_TABLE_CACHE[tablename] = mapping_table logger.debug( f"Initialized {FAISSConceptIDEmbeddingRegistry.__name__} table for model '{model_registry_entry.model_name}'", @@ -48,20 +58,6 @@ def create_faiss_embedding_registry_table( return mapping_table -def initialise_faiss_mapping_tables( - engine: Engine, - model_cache: dict[str, Type[FAISSConceptIDEmbeddingRegistry]], -) -> None: - with Session(engine, expire_on_commit=False) as session: - existing_models = session.scalars(select(ModelRegistry).where(ModelRegistry.backend_type == BackendType.FAISS.value)).all() - session.commit() - - for model_entry in existing_models: - if model_entry.model_name not in model_cache: - # Create the class and cache it - dynamic_table = create_faiss_embedding_registry_table(engine=engine, model_registry_entry=model_entry) - model_cache[model_entry.model_name] = dynamic_table - def add_concept_ids_to_faiss_registry( concept_ids: tuple[int, ...], session: Session, diff --git a/src/omop_emb/backends/pgvector/pgvector_backend.py b/src/omop_emb/backends/pgvector/pgvector_backend.py index e4c88da..394ed92 100644 --- a/src/omop_emb/backends/pgvector/pgvector_backend.py +++ b/src/omop_emb/backends/pgvector/pgvector_backend.py @@ -78,8 +78,7 @@ def upsert_embeddings( dimensions=model_record.dimensions, ) - table = self._get_embedding_table( - session=session, + table = self.get_embedding_table( model_name=model_name, ) @@ -108,8 +107,7 @@ def get_embeddings_by_concept_ids( if not concept_id_tuple: return {} - embedding_table = self._get_embedding_table( - session=session, + embedding_table = self.get_embedding_table( model_name=model_name, ) query = q_embedding_vectors_by_concept_ids( @@ -142,8 +140,7 
@@ def get_nearest_concepts( k: int = 10, ) -> Tuple[Tuple[NearestConceptMatch, ...], ...]: - embedding_table = self._get_embedding_table( - session=session, + embedding_table = self.get_embedding_table( model_name=model_name, ) self.validate_embeddings(embeddings=query_embeddings, dimensions=model_record.dimensions) diff --git a/src/omop_emb/backends/pgvector/pgvector_sql.py b/src/omop_emb/backends/pgvector/pgvector_sql.py index f559d23..371ebae 100644 --- a/src/omop_emb/backends/pgvector/pgvector_sql.py +++ b/src/omop_emb/backends/pgvector/pgvector_sql.py @@ -19,6 +19,9 @@ logger = logging.getLogger(__name__) +_PGVECTOR_TABLE_CACHE: dict[str, type["PGVectorConceptIDEmbeddingTable"]] = {} + + class PGVectorConceptIDEmbeddingTable(ConceptIDEmbeddingBase, Base): """Abstract base for Postgres linter support.""" __abstract__ = True @@ -31,33 +34,28 @@ def create_pg_embedding_table( tablename = model_registry_entry.storage_identifier dimensions = model_registry_entry.dimensions + + cached_table = _PGVECTOR_TABLE_CACHE.get(tablename) + if cached_table is not None: + Base.metadata.create_all(engine, tables=[cached_table.__table__]) # type: ignore[arg-type] + create_indices_for_embedding_table(engine, model_registry_entry) + return cached_table table = type( - tablename, + f"PGVectorConceptIDEmbeddingTable_{tablename}", (PGVectorConceptIDEmbeddingTable,), # It already inherits from Base and ConceptIDBase { "__tablename__": tablename, # We override the attribute with the specific dimension-aware column "embedding": mapped_column(Vector(dimensions), nullable=False, index=False), "__table_args__": {"extend_existing": True}, + "__module__": __name__, } ) Base.metadata.create_all(engine, tables=[table.__table__]) # type: ignore[arg-type] + _PGVECTOR_TABLE_CACHE[tablename] = table - with engine.connect() as conn: - inspector = inspect(conn) - existing_indexes = [idx['name'] for idx in inspector.get_indexes(model_registry_entry.storage_identifier)] - - create_index_sql = 
_create_index_sql( - model_registry_entry.storage_identifier, - IndexType(model_registry_entry.index_type), - ) - if ( - _index_from_storage_identifier(model_registry_entry.storage_identifier) not in existing_indexes and - create_index_sql is not None - ): - conn.execute(text(create_index_sql)) - conn.commit() + create_indices_for_embedding_table(engine, model_registry_entry) logger.debug( f"Initialized {PGVectorConceptIDEmbeddingTable.__name__} table for model '{model_registry_entry.model_name}' with dimensions {model_registry_entry.dimensions} using index_method={model_registry_entry.index_type}", @@ -99,7 +97,21 @@ def add_embeddings_to_registered_table( session.commit() -def _create_index_sql(table_name: str, index_type: IndexType) -> Optional[str]: +def create_indices_for_embedding_table(engine: Engine, model_registry_entry: ModelRegistry) -> None: + with engine.connect() as conn: + inspector = inspect(conn) + existing_indexes = [idx['name'] for idx in inspector.get_indexes(model_registry_entry.storage_identifier)] + + if _index_from_storage_identifier(model_registry_entry.storage_identifier) not in existing_indexes: + query = q_create_index_sql( + model_registry_entry.storage_identifier, + IndexType(model_registry_entry.index_type), + ) + if query is not None: + conn.execute(text(query)) + conn.commit() + +def q_create_index_sql(table_name: str, index_type: IndexType) -> Optional[str]: if index_type == IndexType.FLAT: pass # No additional index needed for flat, as the vector column itself can be used for sequential scan else: diff --git a/src/omop_emb/cli.py b/src/omop_emb/cli.py index 5eb1170..49b7c8e 100644 --- a/src/omop_emb/cli.py +++ b/src/omop_emb/cli.py @@ -12,9 +12,7 @@ from tqdm import tqdm import typer -from omop_emb.backends import ( - EmbeddingConceptFilter, -) +from omop_emb.utils.embedding_utils import EmbeddingConceptFilter from omop_emb.interface import EmbeddingInterface from omop_emb.config import IndexType @@ -108,8 +106,7 @@ def 
add_embeddings( model_name=model, concept_filter=concept_filter, ) - concepts_without_embedding = interface.get_concepts_without_embedding_query( - session=reader, + concepts_without_embedding = interface.q_get_concepts_without_embedding( model_name=model, concept_filter=concept_filter, limit=num_embeddings, diff --git a/src/omop_emb/interface.py b/src/omop_emb/interface.py index 8e32fae..e62645d 100644 --- a/src/omop_emb/interface.py +++ b/src/omop_emb/interface.py @@ -339,16 +339,14 @@ def get_concepts_without_embedding( limit=limit, ) - def get_concepts_without_embedding_query( + def q_get_concepts_without_embedding( self, *, - session: Session, model_name: str, concept_filter: Optional[EmbeddingConceptFilter] = None, limit: Optional[int] = None, ) -> Select: - return self.backend.get_concepts_without_embedding_query( - session=session, + return self.backend.q_get_concepts_without_embedding( model_name=model_name, concept_filter=concept_filter, limit=limit, diff --git a/tests/test_dummy.py b/tests/test_dummy.py deleted file mode 100644 index d7f3c09..0000000 --- a/tests/test_dummy.py +++ /dev/null @@ -1,3 +0,0 @@ -def test_placeholder(): - # This test always passes to keep the CI happy until you write real tests - assert True From 50262077d89379e838da8bef1a6036d0c7c52969 Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Wed, 8 Apr 2026 04:23:40 +0000 Subject: [PATCH 05/19] refactor(backends,registry,tests): decouple local model registry, add composite model identity, and align backend/index APIs --- src/omop_emb/backends/base.py | 292 ++++++++---------- src/omop_emb/backends/factory.py | 8 +- src/omop_emb/backends/faiss/faiss_backend.py | 158 ++-------- src/omop_emb/backends/faiss/faiss_sql.py | 9 +- .../backends/pgvector/pgvector_backend.py | 53 +--- .../backends/pgvector/pgvector_sql.py | 27 +- src/omop_emb/cli.py | 8 - src/omop_emb/config.py | 1 + src/omop_emb/interface.py | 89 +++--- src/omop_emb/model_registry/__init__.py | 2 + .../model_registry_cdm.py} | 
37 ++- .../model_registry/model_registry_manager.py | 167 ++++++++++ .../model_registry/model_registry_types.py | 26 ++ src/omop_emb/utils/embedding_utils.py | 39 --- tests/conftest.py | 22 +- tests/shared_backend_tests.py | 67 ++-- tests/test_config.py | 2 +- tests/test_fixtures.py | 4 - tests/test_interface.py | 48 ++- 19 files changed, 538 insertions(+), 521 deletions(-) create mode 100644 src/omop_emb/model_registry/__init__.py rename src/omop_emb/{backends/registry.py => model_registry/model_registry_cdm.py} (74%) create mode 100644 src/omop_emb/model_registry/model_registry_manager.py create mode 100644 src/omop_emb/model_registry/model_registry_types.py diff --git a/src/omop_emb/backends/base.py b/src/omop_emb/backends/base.py index 4a95de5..2061b64 100644 --- a/src/omop_emb/backends/base.py +++ b/src/omop_emb/backends/base.py @@ -2,8 +2,9 @@ from abc import ABC, abstractmethod from typing import Mapping, Optional, Sequence, Union, Type, TypeVar, Generic, Dict, Any, Callable, Tuple -import re +from pathlib import Path from functools import wraps +import os from numpy import ndarray from sqlalchemy import Engine, select, Integer, ForeignKey, func, Select, text @@ -11,26 +12,28 @@ from omop_alchemy.cdm.model.vocabulary import Concept -from .registry import ModelRegistry, ensure_model_registry_schema -from ..config import BackendType, IndexType, MetricType -from omop_emb.utils.errors import ModelRegistrationConflictError +from ..model_registry import EmbeddingModelRecord, ModelRegistryManager +from ..config import BackendType, IndexType, MetricType, ENV_BASE_STORAGE_DIR from omop_emb.utils.embedding_utils import ( - EmbeddingModelRecord, EmbeddingConceptFilter, NearestConceptMatch, - EmbeddingBackendCapabilities ) def require_registered_model(func: Callable) -> Callable: @wraps(func) - def wrapper(self, session: Session, model_name: str, *args, **kwargs) -> Any: - record = self.get_registered_model(session, model_name) + def wrapper( + self, + model_name: str, 
+ index_type: IndexType, + *args, **kwargs + ) -> Any: + record = self.get_registered_model(model_name=model_name, index_type=index_type) if record is None: raise ValueError( f"Embedding model '{model_name}' is not registered in the FAISS backend." ) - return func(self, session, model_name, record, *args, **kwargs) + return func(self, model_name, index_type, record, *args, **kwargs) return wrapper @@ -54,12 +57,7 @@ class EmbeddingBackend(ABC, Generic[T]): ------------ - Keep embedding generation separate from embedding persistence. - Support multiple storage/retrieval implementations behind one contract. - - Preserve the current operational needs of ``omop-graph``: - - model registration - - embedding population - - embedding lookup by concept ID - - similarity lookup over a set of concept IDs - - nearest-neighbor retrieval with OMOP-oriented filters + - Local storage of model registry metadata, but allow flexible storage of the actual embeddings (e.g., relational tables, FAISS indices, files in object storage, etc.) Notes ----- @@ -69,14 +67,40 @@ class EmbeddingBackend(ABC, Generic[T]): to validate models, resolve concept metadata, and apply domain/vocabulary filters. 
""" - def __init__(self): + DEFAULT_BASE_STORAGE_DIR = str(Path.home() / ".omop_emb") + def __init__( + self, + storage_base_dir: Optional[str | Path] = None, + registry_db_name: Optional[str] = None, + ): super().__init__() - self._embedding_table_cache: dict[str, Type[T]] = {} + + # Local storage for model registry database and more + storage_base_dir = storage_base_dir or os.getenv(ENV_BASE_STORAGE_DIR) or self.DEFAULT_BASE_STORAGE_DIR + self._storage_base_dir = Path(storage_base_dir) + self._storage_base_dir.mkdir(parents=True, exist_ok=True) + + + self._embedding_table_cache: dict[Tuple[str, BackendType, IndexType], Type[T]] = {} + self._embedding_model_registry = ModelRegistryManager( + base_dir=self.storage_base_dir, + db_file=registry_db_name, + ) + + @property + def storage_base_dir(self) -> Path: + """Base directory for any local storage needs of the backend, such as model registry metadata or file-based embedding storage.""" + return self._storage_base_dir @property - def embedding_table_cache(self) -> Dict[str, Type[T]]: + def embedding_table_cache(self) -> Dict[Tuple[str, BackendType, IndexType], Type[T]]: """In-memory cache of dynamically generated ORM classes for embedding tables.""" return self._embedding_table_cache + + @property + def embedding_model_registry(self) -> ModelRegistryManager: + """Manager for embedding model metadata and registry operations.""" + return self._embedding_model_registry @property @abstractmethod @@ -87,71 +111,84 @@ def backend_type(self) -> BackendType: def backend_name(self) -> str: """Stable identifier for this backend implementation.""" return self.backend_type.value + + def _cache_model_record( + self, + engine: Engine, + model_record: EmbeddingModelRecord, + ) -> None: + """Cache the given model record in the embedding table cache. 
This is used to keep track of which models are registered and their associated metadata.""" + dynamic_table = self._create_storage_table(engine=engine, model_record=model_record) + storage_key = (model_record.model_name, self.backend_type, model_record.index_type) + self.embedding_table_cache[storage_key] = dynamic_table - @property + @abstractmethod - def capabilities(self) -> EmbeddingBackendCapabilities: - """Capability flags describing what this backend can do.""" + def _create_storage_table(self, engine: Engine, model_record: EmbeddingModelRecord) -> Type[T]: + """Backend-specific logic to create the dynamic SQLAlchemy class.""" def initialise_store(self, engine: Engine) -> None: """ - Prepare any required storage structures. - - Examples: - - create registry tables - - warm caches from a registry - - create directories or sidecar files + Initialise the model registry and populate the embedding table cache for existing models. + Can extend it to include any backend-specific setup steps (e.g., creating extensions, staging directories) as needed. 
""" + self.pre_initialise_store(engine) - with Session(engine, expire_on_commit=False) as session: - existing_models = session.scalars(select(ModelRegistry).where(ModelRegistry.backend_type == self.backend_type)).all() - #for model_entry in existing_models: - # _heal_legacy_index_method(session, model_entry) - session.commit() - - for model_entry in existing_models: - dynamic_table = self._create_storage_table(engine=engine, entry=model_entry) - self.embedding_table_cache[model_entry.model_name] = dynamic_table - - def list_registered_models(self, session: Session) -> Sequence[EmbeddingModelRecord]: - """Return all embedding models known to this backend.""" - return tuple( - self._record_from_registry_row(row) - for row in session.scalars( - select(ModelRegistry).where(ModelRegistry.backend_type == self.backend_type) - ).all() + registered_models = self.embedding_model_registry.get_registered_models_from_db( + backend_type=self.backend_type ) + if registered_models is None: + return + + for model_record in registered_models: + self.register_model( + engine=engine, + model_name=model_record.model_name, + dimensions=model_record.dimensions, + index_type=model_record.index_type, + metadata=model_record.metadata, + ) + + def pre_initialise_store(self, engine: Engine) -> None: + """Hook for any setup steps that need to happen before the model registry schema is created. 
For example, this is used by the PGVector backend to create the vector extension before the registry tables are created.""" + pass def get_registered_model( self, - session: Session, model_name: str, + index_type: IndexType, ) -> Optional[EmbeddingModelRecord]: """Return metadata for one registered model, if present.""" - row = session.scalar( - select(ModelRegistry).where( - ModelRegistry.model_name == model_name, - ModelRegistry.backend_type == self.backend_type, - ) + registered_model = self.embedding_model_registry.get_registered_models_from_db( + backend_type=self.backend_type, + model_name=model_name, + index_type=index_type ) - if row is None: + if registered_model is None: return None - return self._record_from_registry_row(row) + + assert len(registered_model) == 1, ( + f"Expected exactly one registered model for name '{model_name}' and index type '{index_type.value}', but found {len(registered_model)}. This indicates a data integrity issue in the model registry database." + ) + return registered_model[0] - def is_model_registered(self, session: Session, model_name: str) -> bool: + + def is_model_registered(self, model_name: str, index_type: IndexType) -> bool: """Convenience wrapper over ``get_registered_model``.""" - return self.get_registered_model(session=session, model_name=model_name) is not None + return self.get_registered_model(model_name=model_name, index_type=index_type) is not None def get_concepts_without_embedding( self, session: Session, model_name: str, + index_type: IndexType, concept_filter: Optional[EmbeddingConceptFilter] = None, limit: Optional[int] = None, ) -> Mapping[int, str]: """Return concept IDs and names for concepts that do not have embeddings.""" query = self.q_get_concepts_without_embedding( model_name=model_name, + index_type=index_type, concept_filter=concept_filter, limit=limit ) @@ -160,10 +197,11 @@ def get_concepts_without_embedding( def q_get_concepts_without_embedding( self, model_name: str, + index_type: IndexType, 
concept_filter: Optional[EmbeddingConceptFilter] = None, limit: Optional[int] = None, ) -> Select: - embedding_table = self.get_embedding_table(model_name=model_name) + embedding_table = self.get_embedding_table(model_name=model_name, index_type=index_type) subq = select(1).where(embedding_table.concept_id == Concept.concept_id) query = ( @@ -180,9 +218,10 @@ def get_concepts_without_embedding_count( self, session: Session, model_name: str, + index_type: IndexType, concept_filter: Optional[EmbeddingConceptFilter] = None, ) -> int: - embedding_table = self.get_embedding_table(model_name=model_name) + embedding_table = self.get_embedding_table(model_name=model_name, index_type=index_type) subq = select(1).where(embedding_table.concept_id == Concept.concept_id) query = ( @@ -196,9 +235,6 @@ def get_concepts_without_embedding_count( return session.scalar(query) # type: ignore - @abstractmethod - def _create_storage_table(self, engine: Engine, entry: ModelRegistry) -> Type[T]: - """Backend-specific logic to create the dynamic SQLAlchemy class.""" def register_model( self, @@ -212,69 +248,26 @@ def register_model( """ Shared template method for model registration. 
""" - self.initialise_store(engine) - with Session(engine, expire_on_commit=False) as session: - existing_row = session.scalar( - select(ModelRegistry).where(ModelRegistry.model_name == model_name) - ) - if existing_row is not None: - if existing_row.backend_type != self.backend_type: - raise ModelRegistrationConflictError( - f"Model '{model_name}' is already registered for backend " - f"'{existing_row.backend_type}', not '{self.backend_type}'.", - conflict_field="backend_type" - ) - if existing_row.dimensions != dimensions: - raise ModelRegistrationConflictError( - f"Model '{model_name}' is already registered with dimensions " - f"{existing_row.dimensions}, not {dimensions}.", - conflict_field="dimensions" - ) - if existing_row.index_type != index_type: - raise ModelRegistrationConflictError( - f"Model '{model_name}' is already registered with " - f"index_method='{existing_row.index_type}', not " - f"'{index_type}'. Reuse the existing model " - "configuration or register a new model name.", - conflict_field="index_type" - ) - if existing_row.details != metadata: - raise ModelRegistrationConflictError( - f"Model '{model_name}' is already registered with different " - f"metadata. 
Reuse the existing model name or choose a new one.", - conflict_field="metadata" - ) - return self._record_from_registry_row(existing_row) - - ensure_model_registry_schema(engine) - safe_name = self.safe_model_name(model_name) - storage_name = self.storage_name(safe_model_name=safe_name) - - new_entry = ModelRegistry( + model_record = self.embedding_model_registry.register_model( model_name=model_name, dimensions=dimensions, - storage_identifier=storage_name, - index_type=index_type, backend_type=self.backend_type, - details=metadata + index_type=index_type, + metadata=metadata, ) - - with Session(engine, expire_on_commit=False) as session: - # Add logic here to check for existing records before adding - session.add(new_entry) - session.commit() - - dynamic_table = self._create_storage_table(engine, new_entry) - self.embedding_table_cache[model_name] = dynamic_table - return self._record_from_registry_row(new_entry) + # Local caching + self._cache_model_record(engine=engine, model_record=model_record) + return model_record @abstractmethod @require_registered_model def upsert_embeddings( self, - session: Session, model_name: str, + index_type: IndexType, model_record: EmbeddingModelRecord, + *, + session: Session, concept_ids: Sequence[int], embeddings: ndarray, ) -> None: @@ -288,15 +281,17 @@ Parameters ---------- - session : sqlalchemy.orm.Session - The active database session used for transactional persistence and - model metadata updates. model_name : str The unique identifier or name of the embedding model (e.g., 'text-embedding-3-small'). + index_type : IndexType + The type of vector index used to store the embeddings. model_record : EmbeddingModelRecord A record object containing metadata, dimensions, and configuration specific to the embedding model being processed. + session : sqlalchemy.orm.Session + The active database session used for transactional persistence and + model metadata updates. 
concept_ids : Sequence[int] A sequence of OMOP standard concept IDs corresponding to the ordered rows in the embeddings array. @@ -319,9 +314,10 @@ @require_registered_model def get_embeddings_by_concept_ids( self, - session: Session, model_name: str, + index_type: IndexType, model_record: EmbeddingModelRecord, + session: Session, concept_ids: Sequence[int], ) -> Mapping[int, Sequence[float]]: """Fetch stored embedding vectors keyed by concept ID.""" @@ -331,11 +327,13 @@ @require_registered_model def get_nearest_concepts( self, - session: Session, model_name: str, + index_type: IndexType, + model_record: EmbeddingModelRecord, + *, + session: Session, query_embedding: ndarray, metric_type: MetricType, - *, concept_filter: Optional[EmbeddingConceptFilter] = None, k: int = 10, ) -> Tuple[Tuple[NearestConceptMatch, ...], ...]: @@ -344,10 +342,15 @@ Parameters ---------- - session : Session - SQLAlchemy session for any required relational access. model_name : str Name of the embedding model used to create the embeddings. + index_type : IndexType + The type of vector index used to store the embeddings. + model_record : EmbeddingModelRecord + A record object containing metadata, dimensions, and configuration + specific to the embedding model being queried. Obtained through the decorator's requirement for a registered model. + session : Session + SQLAlchemy session for any required relational access. query_embedding : ndarray The embedding vector to search with. Expected shape is (q, dimension) where q is the number of query vectors and dimension is the size of the embedding space for the model. @@ -363,25 +366,23 @@ Tuple[Tuple[NearestConceptMatch, ...], ...] A tuple of tuples containing nearest concept matches for each query vector. 
The outer tuple corresponds to the query vectors in order, and each inner tuple contains the nearest matches for that query vector, sorted by similarity. Returned shape is (q, k) where q is the number of query vectors and k is the number of nearest neighbors returned per query. """ - - def refresh_model_index(self, session: Session, model_name: str) -> None: - """ - Optional hook for derived index maintenance. - - PostgreSQL/pgvector backends may not need this. File-based or FAISS - backends may choose to rebuild or compact a search index here. - """ - return None - def has_any_embeddings(self, session: Session, model_name: str) -> bool: + def has_any_embeddings( + self, + session: Session, + model_name: str, + index_type: IndexType, + ) -> bool: embedding_table = self.get_embedding_table( model_name=model_name, + index_type=index_type, ) return session.query(embedding_table.concept_id).limit(1).first() is not None def get_embedding_table( self, model_name: str, + index_type: IndexType, ) -> Type[T]: """Return dynamically created ORM class for the embedding table associated with the given model name. This relies on the `initialise_store` method having been called to populate the `embedding_table_cache`. @@ -401,39 +402,12 @@ def get_embedding_table( ValueError If the model name is not found in the embedding table cache, indicating that the store was not properly initialized or the model has not been registered. """ - embedding_table = self.embedding_table_cache.get(model_name) + storage_key = (model_name, self.backend_type, index_type) + embedding_table = self.embedding_table_cache.get(storage_key) if embedding_table is not None: return embedding_table else: - raise ValueError(f"Model '{model_name}' not found in cache. 
Ensure the model is registered and the store is initialized.") - - @staticmethod - def _coerce_registry_metadata(value: object) -> Mapping[str, object]: - if isinstance(value, Mapping): - return dict(value) - return {} - - @staticmethod - def _record_from_registry_row( - row: ModelRegistry, - ) -> EmbeddingModelRecord: - return EmbeddingModelRecord( - model_name=row.model_name, - dimensions=row.dimensions, - backend_type=row.backend_type, - storage_identifier=row.storage_identifier, - index_type=row.index_type, - metadata=row.details, - ) - @staticmethod - def safe_model_name(model_name: str) -> str: - name = model_name.lower() - sanitized = re.sub(r"[^\w]+", "_", name) - sanitized = re.sub(r"_+", "_", sanitized).strip("_") - return sanitized - - def storage_name(self, safe_model_name: str) -> str: - return f"{self.backend_name.lower()}_{safe_model_name}" + raise ValueError(f"Embedding table for model '{model_name}' with index type '{index_type.value}' and backend '{self.backend_type.value}' not found in cache. Ensure that the store is initialized and the model is registered before attempting to access the embedding table.") @staticmethod def validate_embeddings(embeddings: ndarray, dimensions: int): diff --git a/src/omop_emb/backends/factory.py b/src/omop_emb/backends/factory.py index 4b921f9..e7fb670 100644 --- a/src/omop_emb/backends/factory.py +++ b/src/omop_emb/backends/factory.py @@ -39,8 +39,8 @@ def normalize_backend_name(backend_name: Optional[str]) -> BackendType: def get_embedding_backend( backend_name: Optional[str] = None, - *, - faiss_base_dir: Optional[str] = None, + storage_base_dir: Optional[str] = None, + registry_db_name: Optional[str] = None, ) -> EmbeddingBackend: """ Construct an embedding backend implementation by name. @@ -59,7 +59,7 @@ def get_embedding_backend( "PGVector embedding backend requested but its dependencies are not " "available. Install the package using `pip install omop-emb[pgvector]`." 
) from exc - return PGVectorEmbeddingBackend() + return PGVectorEmbeddingBackend(storage_base_dir=storage_base_dir, registry_db_name=registry_db_name) if resolved == BackendType.FAISS: try: @@ -69,7 +69,7 @@ def get_embedding_backend( "FAISS embedding backend requested but its dependencies are not " "available. Install the package with the FAISS extra using `pip install omop-emb[faiss]` or omop-emb[faiss-gpu]." ) from exc - return FaissEmbeddingBackend(base_dir=faiss_base_dir) + return FaissEmbeddingBackend(storage_base_dir=storage_base_dir, registry_db_name=registry_db_name) raise EmbeddingBackendConfigurationError( f"Backend factory reached an unexpected state for backend={resolved!r}." diff --git a/src/omop_emb/backends/faiss/faiss_backend.py b/src/omop_emb/backends/faiss/faiss_backend.py index 3f05af2..6ce13fc 100644 --- a/src/omop_emb/backends/faiss/faiss_backend.py +++ b/src/omop_emb/backends/faiss/faiss_backend.py @@ -2,18 +2,13 @@ import os from pathlib import Path -from typing import Any, Mapping, Optional, Sequence, Type, cast, Dict, Tuple +from typing import Mapping, Optional, Sequence, Type, Dict, Tuple import numpy as np from numpy import ndarray -from sqlalchemy import Engine, insert, select +from sqlalchemy import Engine from sqlalchemy.orm import Session import logging -from dataclasses import dataclass, field -import h5py - -from omop_alchemy.cdm.model.vocabulary import Concept -from omop_emb.backends.registry import ModelRegistry from .faiss_sql import ( FAISSConceptIDEmbeddingRegistry, @@ -25,47 +20,14 @@ from .storage_manager import EmbeddingStorageManager from ..base import EmbeddingBackend, require_registered_model from omop_emb.utils.embedding_utils import ( - EmbeddingBackendCapabilities, EmbeddingConceptFilter, - EmbeddingModelRecord, NearestConceptMatch, get_similarity_from_distance ) +from omop_emb.model_registry import EmbeddingModelRecord logger = logging.getLogger(__name__) -@dataclass -class FaissMetadata: - """Strongly typed metadata 
required for the FAISS backend to operate.""" - index_file_path: str - metadata_bucket: dict[str, Any] = field(default_factory=dict) - - def to_dict(self) -> dict[str, Any]: - """Serializes to a flat dictionary for SQLAlchemy JSON storage.""" - # We flatten it so it looks clean in the database - result = self.metadata_bucket.copy() - result["index_file_path"] = self.index_file_path - return result - - @classmethod - def from_dict(cls, data: Mapping[str, Any]) -> "FaissMetadata": - """Safely reconstructs the object from the database JSON.""" - data_copy = dict(data) - - # Pop the required FAISS fields, raising a clear error if missing - try: - index_path = data_copy.pop("index_file_path") - except KeyError as e: - raise ValueError( - f"Corrupted FAISS metadata in registry. Missing required key: {e}" - ) - - # Everything else goes back into user_metadata - return cls( - index_file_path=str(index_path), - metadata_bucket=data_copy - ) - class FaissEmbeddingBackend(EmbeddingBackend[FAISSConceptIDEmbeddingRegistry]): """ @@ -89,42 +51,26 @@ class FaissEmbeddingBackend(EmbeddingBackend[FAISSConceptIDEmbeddingRegistry]): DEFAULT_FAISS_DIR = ".omop_emb/faiss" - def __init__(self, base_dir: Optional[str | os.PathLike[str]] = None): - super().__init__() - self.base_dir = Path( - base_dir - or os.getenv(ENV_OMOP_EMB_FAISS_INDEX_DIR) - or FaissEmbeddingBackend.DEFAULT_FAISS_DIR - ).expanduser() + def __init__( + self, + storage_base_dir: Optional[str | Path] = None, + registry_db_name: Optional[str] = None, + ): + super().__init__( + storage_base_dir=storage_base_dir, + registry_db_name=registry_db_name, + ) self.embedding_storage_managers: Dict[str, EmbeddingStorageManager] = {} @property def backend_type(self) -> BackendType: return BackendType.FAISS - @property - def capabilities(self) -> EmbeddingBackendCapabilities: - return EmbeddingBackendCapabilities( - stores_embeddings=True, - supports_incremental_upsert=True, - supports_nearest_neighbor_search=True, - 
supports_server_side_similarity=False, - supports_filter_by_concept_ids=True, - supports_filter_by_domain=True, - supports_filter_by_vocabulary=True, - supports_filter_by_standard_flag=True, - requires_explicit_index_refresh=False, - ) - - def initialise_store(self, engine) -> None: - self.base_dir.mkdir(parents=True, exist_ok=True) - return super().initialise_store(engine) - - def _create_storage_table(self, engine: Engine, entry: ModelRegistry) -> Type[FAISSConceptIDEmbeddingRegistry]: - return create_faiss_embedding_registry_table(engine=engine, model_registry_entry=entry) + def _create_storage_table(self, engine: Engine, model_record: EmbeddingModelRecord) -> Type[FAISSConceptIDEmbeddingRegistry]: + return create_faiss_embedding_registry_table(engine=engine, model_record=model_record) def get_safe_model_dir(self, model_name: str) -> Path: - return self.base_dir / self.safe_model_name(model_name) + return self.storage_base_dir / self.embedding_model_registry.safe_model_name(model_name) def get_storage_manager( self, @@ -166,31 +112,26 @@ def register_model( ) -> EmbeddingModelRecord: # Create the storage for the model on disk - safe_name = self.safe_model_name(model_name) - model_dir = self.base_dir / safe_name + safe_name = self.embedding_model_registry.safe_model_name(model_name) + model_dir = self.storage_base_dir / safe_name model_dir.mkdir(parents=False, exist_ok=True) logger.info(f"Created model directory: {model_dir}") - # NOTE: not sure if still necessary as the manager takes care of it all. 
- #faiss_metadata = FaissMetadata( - # index_file_path=str(self._index_path(model_name)), - # metadata_bucket=dict(metadata), - #) return super().register_model( engine=engine, model_name=model_name, dimensions=dimensions, index_type=index_type, - #metadata=faiss_metadata.to_dict(), metadata=metadata ) @require_registered_model def upsert_embeddings( self, - session: Session, model_name: str, + index_type: IndexType, model_record: EmbeddingModelRecord, + session: Session, concept_ids: Sequence[int], embeddings: ndarray, metric_type: Optional[MetricType] = None, @@ -204,15 +145,17 @@ def upsert_embeddings( Parameters ---------- - session : sqlalchemy.orm.Session - The active database session used for transactional persistence and - model metadata updates. model_name : str The unique identifier or name of the embedding model (e.g., 'text-embedding-3-small'). + index_type : IndexType + The type of vector index used to store the embeddings. model_record : EmbeddingModelRecord A record object containing metadata, dimensions, and configuration specific to the embedding model being processed. + session : sqlalchemy.orm.Session + The active database session used for transactional persistence and + model metadata updates. concept_ids : Sequence[int] A sequence of OMOP standard concept IDs corresponding to the ordered rows in the embeddings array. 
@@ -248,7 +191,10 @@ def upsert_embeddings( add_concept_ids_to_faiss_registry( concept_ids=concept_id_tuple, session=session, - registered_table=self.get_embedding_table(model_name), + registered_table=self.get_embedding_table( + model_name=model_name, + index_type=index_type, + ), ) except Exception as e: session.rollback() @@ -271,9 +217,10 @@ def upsert_embeddings( @require_registered_model def get_nearest_concepts( self, - session: Session, model_name: str, + index_type: IndexType, model_record: EmbeddingModelRecord, + session: Session, query_embeddings: np.ndarray, metric_type: MetricType, *, @@ -292,7 +239,7 @@ def get_nearest_concepts( self.validate_embeddings(embeddings=query_embeddings, dimensions=model_record.dimensions) q_permitted_concept_ids = q_concept_ids_with_embeddings( - embedding_table=self.get_embedding_table(model_name), + embedding_table=self.get_embedding_table(model_name=model_name, index_type=index_type), concept_filter=concept_filter, limit=None ) @@ -334,17 +281,14 @@ def get_nearest_concepts( )) matches.append(tuple(matches_per_query)) return tuple(matches) - - @require_registered_model - def refresh_model_index(self, session: Session, model_name: str, model_record: EmbeddingModelRecord) -> None: - raise NotImplementedError("Explicit index refresh is not required for the FAISS backend as it updates the index on each upsert. 
This method is a no-op.") @require_registered_model def get_embeddings_by_concept_ids( self, - session: Session, model_name: str, + index_type: IndexType, model_record: EmbeddingModelRecord, + session: Session, concept_ids: Sequence[int] ) -> Mapping[int, Sequence[float]]: concept_id_tuple = tuple(concept_ids) @@ -360,39 +304,3 @@ def get_embeddings_by_concept_ids( ) return storage_manager.get_embeddings_by_concept_ids(concept_ids=concept_ids_np) - - - def _get_filtered_concept_ids( - self, - session: Session, - concept_filter: EmbeddingConceptFilter, - ) -> tuple[int, ...]: - query = select(Concept.concept_id) - if concept_filter is not None: - query = concept_filter.apply(query) - return tuple(int(row.concept_id) for row in session.execute(query)) - - @staticmethod - def _normalize_matrix(matrix: np.ndarray) -> np.ndarray: - norms = np.linalg.norm(matrix, axis=1, keepdims=True) - norms[norms == 0] = 1.0 - return matrix / norms - - @staticmethod - def _filter_is_empty(concept_filter: EmbeddingConceptFilter) -> bool: - return ( - concept_filter.concept_ids is None - and concept_filter.domains is None - and concept_filter.vocabularies is None - and not concept_filter.require_standard - ) - - @staticmethod - def _create_faiss_index(*, faiss, index_type: str, dimensions: int): - index_type = str(getattr(index_type, "value", index_type)) - if index_type in {"IndexFlatIP", "flat", "flatip"}: - return faiss.IndexFlatIP(dimensions) - raise ValueError( - f"Unsupported FAISS index_type={index_type!r} in first-pass backend. " - "Currently supported: IndexFlatIP." 
- ) \ No newline at end of file diff --git a/src/omop_emb/backends/faiss/faiss_sql.py b/src/omop_emb/backends/faiss/faiss_sql.py index b99e050..a07e83e 100644 --- a/src/omop_emb/backends/faiss/faiss_sql.py +++ b/src/omop_emb/backends/faiss/faiss_sql.py @@ -9,9 +9,8 @@ from orm_loader.helpers import Base from omop_alchemy.cdm.model.vocabulary import Concept +from omop_emb.model_registry import EmbeddingModelRecord from ..base import ConceptIDEmbeddingBase -from omop_emb.config import BackendType -from ..registry import ModelRegistry from omop_emb.utils.embedding_utils import EmbeddingConceptFilter logger = logging.getLogger(__name__) @@ -26,13 +25,13 @@ class FAISSConceptIDEmbeddingRegistry(ConceptIDEmbeddingBase, Base): def create_faiss_embedding_registry_table( engine: Engine, - model_registry_entry: ModelRegistry, + model_record: EmbeddingModelRecord, ) -> Type[FAISSConceptIDEmbeddingRegistry]: """ Creates a dynamic SQLAlchemy ORM class that tracks which concept_ids are present in the FAISS/H5 storage for a specific model. 
""" - tablename = model_registry_entry.storage_identifier + tablename = model_record.storage_identifier cached_table = _FAISS_REGISTRY_TABLE_CACHE.get(tablename) if cached_table is not None: @@ -53,7 +52,7 @@ def create_faiss_embedding_registry_table( _FAISS_REGISTRY_TABLE_CACHE[tablename] = mapping_table logger.debug( - f"Initialized {FAISSConceptIDEmbeddingRegistry.__name__} table for model '{model_registry_entry.model_name}'", + f"Initialized {FAISSConceptIDEmbeddingRegistry.__name__} table for model '{model_record.model_name}'", ) return mapping_table diff --git a/src/omop_emb/backends/pgvector/pgvector_backend.py b/src/omop_emb/backends/pgvector/pgvector_backend.py index 394ed92..edff569 100644 --- a/src/omop_emb/backends/pgvector/pgvector_backend.py +++ b/src/omop_emb/backends/pgvector/pgvector_backend.py @@ -14,15 +14,13 @@ create_pg_embedding_table, add_embeddings_to_registered_table, ) -from ..registry import ModelRegistry -from omop_emb.config import BackendType, MetricType +from omop_emb.config import BackendType, MetricType, IndexType from ..base import EmbeddingBackend, require_registered_model from omop_emb.utils.embedding_utils import ( - EmbeddingBackendCapabilities, EmbeddingConceptFilter, - EmbeddingModelRecord, NearestConceptMatch, ) +from omop_emb.model_registry import EmbeddingModelRecord class PGVectorEmbeddingBackend(EmbeddingBackend[PGVectorConceptIDEmbeddingTable]): """ @@ -38,35 +36,21 @@ class PGVectorEmbeddingBackend(EmbeddingBackend[PGVectorConceptIDEmbeddingTable] def backend_type(self) -> BackendType: return BackendType.PGVECTOR - @property - def capabilities(self) -> EmbeddingBackendCapabilities: - return EmbeddingBackendCapabilities( - stores_embeddings=True, - supports_incremental_upsert=True, - supports_nearest_neighbor_search=True, - supports_server_side_similarity=True, - supports_filter_by_concept_ids=True, - supports_filter_by_domain=True, - supports_filter_by_vocabulary=True, - supports_filter_by_standard_flag=True, - 
requires_explicit_index_refresh=False, - ) + def _create_storage_table(self, engine: Engine, model_record: EmbeddingModelRecord) -> Type[PGVectorConceptIDEmbeddingTable]: + return create_pg_embedding_table(engine=engine, model_record=model_record) - def _create_storage_table(self, engine: Engine, entry: ModelRegistry) -> Type[PGVectorConceptIDEmbeddingTable]: - return create_pg_embedding_table(engine=engine, model_registry_entry=entry) - - def initialise_store(self, engine: Engine) -> None: + def pre_initialise_store(self, engine: Engine) -> None: with Session(engine, expire_on_commit=False) as session: session.execute(text("CREATE EXTENSION IF NOT EXISTS vector CASCADE;")) session.commit() - return super().initialise_store(engine) @require_registered_model def upsert_embeddings( self, - session: Session, model_name: str, + index_type: IndexType, model_record: EmbeddingModelRecord, + session: Session, concept_ids: Sequence[int], embeddings: ndarray, ) -> None: @@ -80,6 +64,7 @@ def upsert_embeddings( table = self.get_embedding_table( model_name=model_name, + index_type=index_type, ) try: @@ -98,9 +83,10 @@ def upsert_embeddings( @require_registered_model def get_embeddings_by_concept_ids( self, - session: Session, model_name: str, + index_type: IndexType, model_record: EmbeddingModelRecord, + session: Session, concept_ids: Sequence[int], ) -> Mapping[int, Sequence[float]]: concept_id_tuple = tuple(concept_ids) @@ -109,6 +95,7 @@ def get_embeddings_by_concept_ids( embedding_table = self.get_embedding_table( model_name=model_name, + index_type=index_type, ) query = q_embedding_vectors_by_concept_ids( embedding_table=embedding_table, @@ -130,9 +117,10 @@ def get_embeddings_by_concept_ids( @require_registered_model def get_nearest_concepts( self, - session: Session, model_name: str, + index_type: IndexType, model_record: EmbeddingModelRecord, + session: Session, query_embeddings: ndarray, *, metric_type: MetricType, @@ -142,6 +130,7 @@ def get_nearest_concepts( 
embedding_table = self.get_embedding_table( model_name=model_name, + index_type=index_type, ) self.validate_embeddings(embeddings=query_embeddings, dimensions=model_record.dimensions) @@ -170,17 +159,3 @@ def get_nearest_concepts( return tuple(tuple(matches) for matches in results) - def refresh_model_index(self, session: Session, model_name: str) -> None: - # pgvector indexes update as rows are written, so no explicit refresh is - # required for the current PostgreSQL backend. - return None - - def _record_from_registry_row(self, row: ModelRegistry) -> EmbeddingModelRecord: - return EmbeddingModelRecord( - model_name=row.model_name, - dimensions=row.dimensions, - backend_type=row.backend_type, - storage_identifier=row.storage_identifier, - index_type=row.index_type, - metadata=self._coerce_registry_metadata(row.details), - ) diff --git a/src/omop_emb/backends/pgvector/pgvector_sql.py b/src/omop_emb/backends/pgvector/pgvector_sql.py index 371ebae..b89fb1e 100644 --- a/src/omop_emb/backends/pgvector/pgvector_sql.py +++ b/src/omop_emb/backends/pgvector/pgvector_sql.py @@ -12,7 +12,7 @@ from numpy import ndarray from omop_emb.config import BackendType, IndexType, MetricType -from ..registry import ModelRegistry +from omop_emb.model_registry import EmbeddingModelRecord from ..base import ConceptIDEmbeddingBase from omop_emb.utils.embedding_utils import EmbeddingConceptFilter, get_similarity_from_distance @@ -29,24 +29,23 @@ class PGVectorConceptIDEmbeddingTable(ConceptIDEmbeddingBase, Base): def create_pg_embedding_table( engine: Engine, - model_registry_entry: ModelRegistry, + model_record: EmbeddingModelRecord, ) -> Type[PGVectorConceptIDEmbeddingTable]: # Note the specific Type hint here - tablename = model_registry_entry.storage_identifier - dimensions = model_registry_entry.dimensions + tablename = model_record.storage_identifier + dimensions = model_record.dimensions cached_table = _PGVECTOR_TABLE_CACHE.get(tablename) if cached_table is not None: 
Base.metadata.create_all(engine, tables=[cached_table.__table__]) # type: ignore[arg-type] - create_indices_for_embedding_table(engine, model_registry_entry) + create_indices_for_embedding_table(engine, model_record) return cached_table table = type( f"PGVectorConceptIDEmbeddingTable_{tablename}", - (PGVectorConceptIDEmbeddingTable,), # It already inherits from Base and ConceptIDBase + (PGVectorConceptIDEmbeddingTable,), { "__tablename__": tablename, - # We override the attribute with the specific dimension-aware column "embedding": mapped_column(Vector(dimensions), nullable=False, index=False), "__table_args__": {"extend_existing": True}, "__module__": __name__, @@ -55,10 +54,10 @@ def create_pg_embedding_table( Base.metadata.create_all(engine, tables=[table.__table__]) # type: ignore[arg-type] _PGVECTOR_TABLE_CACHE[tablename] = table - create_indices_for_embedding_table(engine, model_registry_entry) + create_indices_for_embedding_table(engine, model_record) logger.debug( - f"Initialized {PGVectorConceptIDEmbeddingTable.__name__} table for model '{model_registry_entry.model_name}' with dimensions {model_registry_entry.dimensions} using index_method={model_registry_entry.index_type}", + f"Initialized {PGVectorConceptIDEmbeddingTable.__name__} table for model '{model_record.model_name}' with dimensions {model_record.dimensions} using index_method={model_record.index_type}", ) return table @@ -97,15 +96,15 @@ def add_embeddings_to_registered_table( session.commit() -def create_indices_for_embedding_table(engine: Engine, model_registry_entry: ModelRegistry) -> None: +def create_indices_for_embedding_table(engine: Engine, model_record: EmbeddingModelRecord) -> None: with engine.connect() as conn: inspector = inspect(conn) - existing_indexes = [idx['name'] for idx in inspector.get_indexes(model_registry_entry.storage_identifier)] + existing_indexes = [idx['name'] for idx in inspector.get_indexes(model_record.storage_identifier)] - if 
_index_from_storage_identifier(model_registry_entry.storage_identifier) not in existing_indexes: + if _index_from_storage_identifier(model_record.storage_identifier) not in existing_indexes: query = q_create_index_sql( - model_registry_entry.storage_identifier, - IndexType(model_registry_entry.index_type), + model_record.storage_identifier, + model_record.index_type, ) if query is not None: conn.execute(text(query)) diff --git a/src/omop_emb/cli.py b/src/omop_emb/cli.py index 49b7c8e..5d042b1 100644 --- a/src/omop_emb/cli.py +++ b/src/omop_emb/cli.py @@ -93,14 +93,6 @@ def add_embeddings( interface.initialise_store(engine) with Session(engine) as reader, Session(engine) as writer: - interface.ensure_model_registered( - engine=engine, - session=reader, - model_name=model, - index_type=index_type, - dimensions=embedding_dim - ) - total_concepts = num_embeddings or interface.get_concepts_without_embedding_count( session=reader, model_name=model, diff --git a/src/omop_emb/config.py b/src/omop_emb/config.py index 5e8418f..2ae83de 100644 --- a/src/omop_emb/config.py +++ b/src/omop_emb/config.py @@ -3,6 +3,7 @@ ENV_OMOP_EMB_BACKEND = "OMOP_EMB_BACKEND" ENV_OMOP_EMB_FAISS_INDEX_DIR = "OMOP_EMB_FAISS_INDEX_DIR" +ENV_BASE_STORAGE_DIR = "OMOP_EMB_BASE_STORAGE_DIR" # For backends that use file-based storage, e.g. 
FAISS with on-disk indices or registry metadata storage class BackendType(StrEnum): diff --git a/src/omop_emb/interface.py b/src/omop_emb/interface.py index e62645d..1bba90e 100644 --- a/src/omop_emb/interface.py +++ b/src/omop_emb/interface.py @@ -14,10 +14,8 @@ EmbeddingBackend, get_embedding_backend, ) -from omop_emb.utils.embedding_utils import ( - EmbeddingConceptFilter, - EmbeddingModelRecord, -) +from omop_emb.utils.embedding_utils import EmbeddingConceptFilter +from omop_emb.model_registry import EmbeddingModelRecord from .config import IndexType, MetricType @@ -50,13 +48,14 @@ def from_backend_name( cls, embedding_client: Optional[LLMClient] = None, backend_name: Optional[str] = None, - *, - faiss_base_dir: Optional[str] = None, + storage_base_dir: Optional[str] = None, + registry_db_name: Optional[str] = None, ) -> EmbeddingInterface: return cls( backend=get_embedding_backend( backend_name=backend_name, - faiss_base_dir=faiss_base_dir, + storage_base_dir=storage_base_dir, + registry_db_name=registry_db_name, ), embedding_client=embedding_client, ) @@ -66,8 +65,8 @@ def initialise_store(self, engine: Engine) -> None: def get_model_table_name( self, - session: Session, model_name: str, + index_type: IndexType, ) -> Optional[str]: """ Legacy helper preserved for compatibility. @@ -75,31 +74,49 @@ def get_model_table_name( Historically this returned a PostgreSQL table name. Under the backend abstraction it returns the backend-specific storage identifier. 
""" - record = self.backend.get_registered_model(session=session, model_name=model_name) + record = self.backend.get_registered_model(model_name=model_name, index_type=index_type) return record.storage_identifier if record is not None else None def is_model_registered( self, - session: Session, model_name: str, + index_type: IndexType, ) -> bool: - return self.backend.is_model_registered(session=session, model_name=model_name) + return self.backend.is_model_registered(model_name=model_name, index_type=index_type) + def register_model( + self, + engine: Engine, + model_name: str, + dimensions: int, + index_type: IndexType, + metadata: Mapping[str, object] = {}, + ) -> EmbeddingModelRecord: + return self.backend.register_model( + engine=engine, + model_name=model_name, + dimensions=dimensions, + index_type=index_type, + metadata=metadata, + ) def has_any_embeddings( self, session: Session, embedding_model_name: str, + index_type: IndexType, ) -> bool: return self.backend.has_any_embeddings( session=session, model_name=embedding_model_name, + index_type=index_type, ) def get_nearest_concepts( self, session: Session, model_name: str, + index_type: IndexType, query_embedding: np.ndarray, *, metric_type: MetricType | str, @@ -116,6 +133,8 @@ def get_nearest_concepts( SQLAlchemy session for any required relational access. model_name : str Name of the embedding model used to create the embeddings. + index_type : IndexType + The type of vector index used to store the embeddings. query_embedding : ndarray The embedding vector to search with. Expected shape is (q, dimension) where q is the number of query vectors and dimension is the size of the embedding space for the model. 
@@ -141,6 +160,7 @@ def get_nearest_concepts( nearest_concepts = self.backend.get_nearest_concepts( session=session, model_name=model_name, + index_type=index_type, query_embedding=query_embedding, concept_filter=concept_filter, metric_type=metric_type, @@ -152,6 +172,7 @@ def get_nearest_concepts_by_texts( self, session: Session, embedding_model_name: str, + index_type: IndexType, query_texts: str | Tuple[str, ...] | List[str], *, metric_type: MetricType, @@ -169,6 +190,8 @@ def get_nearest_concepts_by_texts( SQLAlchemy session for any required relational access. embedding_model_name : str Name of the embedding model used to create the embeddings. + index_type : IndexType + The type of vector index used to store the embeddings. query_texts : str | Tuple[str, ...] | List[str] The text(s) to embed and search with. If a single string is provided, it will be embedded and searched as one query. If a tuple or list of strings is provided, each string will be embedded and searched separately, and the results will be returned in the same order as the input texts. 
metric_type : MetricType @@ -195,6 +218,7 @@ def get_nearest_concepts_by_texts( return self.get_nearest_concepts( session=session, model_name=embedding_model_name, + index_type=index_type, query_embedding=query_embeddings, metric_type=metric_type, concept_filter=concept_filter, @@ -205,11 +229,13 @@ def get_embeddings_by_concept_ids( self, session: Session, embedding_model_name: str, + index_type: IndexType, concept_ids: Tuple[int, ...], ) -> Mapping[int, Sequence[float]]: return self.backend.get_embeddings_by_concept_ids( session=session, model_name=embedding_model_name, + index_type=index_type, concept_ids=concept_ids, ) @@ -223,6 +249,7 @@ def initialise_tables(self, engine: Engine): def add_to_db( self, session: Session, + index_type: IndexType, concept_ids: Tuple[int, ...], embeddings: ndarray, model: str, @@ -236,38 +263,10 @@ def add_to_db( return self.backend.upsert_embeddings( session=session, model_name=model, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) - - def ensure_model_registered( - self, - *, - engine: Engine, - session: Session, - model_name: str, - dimensions: int, - index_type: IndexType, - metadata: Mapping[str, object] = {}, - ) -> EmbeddingModelRecord: - """ - Ensure the embedding model exists in the selected backend. 
- """ - - existing = self.backend.get_registered_model( - session=session, - model_name=model_name, - ) - if existing is not None: - return existing - else: - return self.backend.register_model( - engine=engine, - model_name=model_name, - dimensions=dimensions, - index_type=index_type, - metadata=metadata, - ) def embed_texts( self, @@ -286,12 +285,14 @@ def upsert_concept_embeddings( *, session: Session, model_name: str, + index_type: IndexType, concept_ids: Sequence[int], embeddings: ndarray, ) -> None: self.backend.upsert_embeddings( session=session, model_name=model_name, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) @@ -301,6 +302,7 @@ def embed_and_upsert_concepts( *, session: Session, model_name: str, + index_type: IndexType, concept_ids: Sequence[int], concept_texts: Sequence[str], embedding_client: Optional[LLMClient] = None, @@ -319,6 +321,7 @@ def embed_and_upsert_concepts( self.upsert_concept_embeddings( session=session, model_name=model_name, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) @@ -329,12 +332,14 @@ def get_concepts_without_embedding( *, session: Session, model_name: str, + index_type: IndexType, concept_filter: Optional[EmbeddingConceptFilter] = None, limit: Optional[int] = None, ) -> Mapping[int, str]: return self.backend.get_concepts_without_embedding( session=session, model_name=model_name, + index_type=index_type, concept_filter=concept_filter, limit=limit, ) @@ -343,11 +348,13 @@ def q_get_concepts_without_embedding( self, *, model_name: str, + index_type: IndexType, concept_filter: Optional[EmbeddingConceptFilter] = None, limit: Optional[int] = None, ) -> Select: return self.backend.q_get_concepts_without_embedding( model_name=model_name, + index_type=index_type, concept_filter=concept_filter, limit=limit, ) @@ -357,10 +364,12 @@ def get_concepts_without_embedding_count( *, session: Session, model_name: str, + index_type: IndexType, concept_filter: 
Optional[EmbeddingConceptFilter] = None, ) -> int: return self.backend.get_concepts_without_embedding_count( session=session, model_name=model_name, + index_type=index_type, concept_filter=concept_filter, ) diff --git a/src/omop_emb/model_registry/__init__.py b/src/omop_emb/model_registry/__init__.py new file mode 100644 index 0000000..47c7feb --- /dev/null +++ b/src/omop_emb/model_registry/__init__.py @@ -0,0 +1,2 @@ +from .model_registry_types import EmbeddingModelRecord +from .model_registry_manager import ModelRegistryManager \ No newline at end of file diff --git a/src/omop_emb/backends/registry.py b/src/omop_emb/model_registry/model_registry_cdm.py similarity index 74% rename from src/omop_emb/backends/registry.py rename to src/omop_emb/model_registry/model_registry_cdm.py index 9b16e23..5dda099 100644 --- a/src/omop_emb/backends/registry.py +++ b/src/omop_emb/model_registry/model_registry_cdm.py @@ -1,9 +1,7 @@ from __future__ import annotations from sqlalchemy import DateTime, Engine, Integer, JSON, String, func, inspect, text, Enum -from sqlalchemy.orm import mapped_column, validates - -from orm_loader.helpers import Base +from sqlalchemy.orm import DeclarativeBase, mapped_column, validates from ..config import ( is_index_type_supported_for_backend, @@ -12,7 +10,11 @@ BackendType ) -class ModelRegistry(Base): +class ModelRegistryBase(DeclarativeBase): + """Dedicated declarative base for local model registry metadata.""" + + +class ModelRegistry(ModelRegistryBase): """ Shared database-backed registry for embedding models across backends. @@ -26,33 +28,34 @@ class ModelRegistry(Base): __tablename__ = "model_registry" - # TODO: Think about having multiple models per index and backend. Would require to - # have the storage_identifier (i.e. 
the table_name) to be unique as well, which is - # currently not enforced explicitly but implicitily as it only depends on model_name that is unique model_name = mapped_column(String, primary_key=True) + backend_type = mapped_column(Enum(BackendType, native_enum=False), nullable=False, primary_key=True) + index_type = mapped_column(Enum(IndexType, native_enum=False), nullable=False, primary_key=True) dimensions = mapped_column(Integer, nullable=False) storage_identifier = mapped_column("table_name", String, unique=True, nullable=False) - index_type = mapped_column(Enum(IndexType, native_enum=False), nullable=False) - backend_type = mapped_column(Enum(BackendType, native_enum=False), nullable=False) details = mapped_column(JSON, nullable=True, default=dict) created_at = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now()) updated_at = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()) + + @validates("backend_type") + def validate_backend_type(self, key, backend_type): + if backend_type not in BackendType: + raise ValueError(f"Unsupported backend type: {backend_type}. Supported backends: {list(BackendType)}") + return backend_type + @validates("index_type") def validate_index_for_backend(self, key, index_type): - if is_index_type_supported_for_backend(self.backend_type, index_type): + if self.backend_type is None: + raise ValueError("Instantiate the registry entry with a valid backend_type before setting index_type.") + + if not is_index_type_supported_for_backend(self.backend_type, index_type): raise ValueError( f"Backend {self.backend_type} does not support {index_type}. " f"Supported: {get_supported_index_types_for_backend(self.backend_type)}" ) return index_type - - @validates("backend_type") - def validate_backend_type(self, key, backend_type): - if backend_type not in BackendType: - raise ValueError(f"Unsupported backend type: {backend_type}. 
Supported backends: {list(BackendType)}") - return backend_type def ensure_model_registry_schema(engine: Engine) -> None: - Base.metadata.create_all(engine, tables=[ModelRegistry.__table__]) # type: ignore[arg-type] + ModelRegistryBase.metadata.create_all(engine, tables=[ModelRegistry.__table__]) # type: ignore[arg-type] diff --git a/src/omop_emb/model_registry/model_registry_manager.py b/src/omop_emb/model_registry/model_registry_manager.py new file mode 100644 index 0000000..df9c851 --- /dev/null +++ b/src/omop_emb/model_registry/model_registry_manager.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +from sqlalchemy import Engine, create_engine, and_, select +from sqlalchemy.orm import Session +import logging +from typing import Optional, Mapping +import pathlib +import re + +from omop_emb.config import ( + IndexType, + BackendType +) +from .model_registry_cdm import ModelRegistry, ensure_model_registry_schema +from .model_registry_types import EmbeddingModelRecord +from omop_emb.utils.errors import ModelRegistrationConflictError + +logger = logging.getLogger(__name__) + + + +class ModelRegistryManager: + """Manages model registry (metadata) for embedding models locally in a separate SQLite database.""" + DB_FILENAME = "metadata.db" + def __init__( + self, + base_dir: str | pathlib.Path, + db_file: Optional[str] = None, + ): + db_file = db_file or self.DB_FILENAME + self.db_path = pathlib.Path(base_dir) / db_file + self.engine = create_engine(f"sqlite:///{self.db_path}") + + ensure_model_registry_schema(self.engine) + + def get_registered_models_from_db( + self, + backend_type: Optional[BackendType] = None, + model_name: Optional[str] = None, + index_type: Optional[IndexType] = None, + ) -> Optional[tuple[EmbeddingModelRecord, ...]]: + """ + Get the registered embedding models of the database, optionally filtered by backend type, model name, or index type. 
+ """ + stmt = select(ModelRegistry) + if backend_type is not None: + stmt = stmt.where(ModelRegistry.backend_type == backend_type) + if model_name is not None: + stmt = stmt.where(ModelRegistry.model_name == model_name) + if index_type is not None: + stmt = stmt.where(ModelRegistry.index_type == index_type) + + with Session(self.engine, expire_on_commit=False) as session: + existing_models = session.scalars(stmt).all() + + if not existing_models: + return None + + return tuple(self._registry_entry_to_model_record(match) for match in existing_models) + + + def register_model( + self, + model_name: str, + dimensions: int, + *, + backend_type: BackendType, + index_type: IndexType, + metadata: Mapping[str, object] = {}, + ) -> EmbeddingModelRecord: + """ + Shared template method for model registration. + """ + with Session(self.engine, expire_on_commit=False) as session: + existing_row = session.scalar( + select(ModelRegistry).where( + and_( + ModelRegistry.model_name == model_name, + ModelRegistry.backend_type == backend_type, + ModelRegistry.index_type == index_type, + ) + ) + ) + if existing_row is not None: + if existing_row.backend_type != backend_type: + raise ModelRegistrationConflictError( + f"Model '{model_name}' is already registered with backend " + f"'{existing_row.backend_type}', not '{backend_type}'. " + "Reuse the existing model name or choose a new one.", + conflict_field="backend_type" + ) + + if existing_row.dimensions != dimensions: + raise ModelRegistrationConflictError( + f"Model '{model_name}' is already registered with dimensions " + f"{existing_row.dimensions}, not {dimensions}.", + conflict_field="dimensions" + ) + if existing_row.index_type != index_type: + raise ModelRegistrationConflictError( + f"Model '{model_name}' is already registered with " + f"index_method='{existing_row.index_type}', not " + f"'{index_type}'. 
Reuse the existing model " + "configuration or register a new model name.", + conflict_field="index_type" + ) + if existing_row.details != metadata: + raise ModelRegistrationConflictError( + f"Model '{model_name}' is already registered with different " + f"metadata. Reuse the existing model name or choose a new one.", + conflict_field="metadata" + ) + return self._registry_entry_to_model_record(existing_row) + + safe_name = self.safe_model_name(model_name) + storage_name = self.storage_name( + safe_model_name=safe_name, + index_type=index_type, + backend_type=backend_type + ) + + new_entry = ModelRegistry( + model_name=model_name, + backend_type=backend_type, + index_type=index_type, + dimensions=dimensions, + storage_identifier=storage_name, + details=metadata + ) + + with Session(self.engine, expire_on_commit=False) as session: + session.add(new_entry) + session.commit() + return self._registry_entry_to_model_record(new_entry) + + + @staticmethod + def safe_model_name(model_name: str) -> str: + name = model_name.lower() + sanitized = re.sub(r"[^\w]+", "_", name) + sanitized = re.sub(r"_+", "_", sanitized).strip("_") + return sanitized + + @staticmethod + def storage_name( + safe_model_name: str, + index_type: IndexType, + backend_type: BackendType + ) -> str: + return f"{backend_type.value.lower()}_{safe_model_name}_{index_type.value}" + + @staticmethod + def _coerce_registry_metadata(value: object) -> Mapping[str, object]: + if isinstance(value, Mapping): + return dict(value) + return {} + + @staticmethod + def _registry_entry_to_model_record(entry: ModelRegistry) -> EmbeddingModelRecord: + return EmbeddingModelRecord( + model_name=entry.model_name, + dimensions=entry.dimensions, + backend_type=entry.backend_type, + index_type=entry.index_type, + storage_identifier=entry.storage_identifier, + metadata=entry.details, + ) diff --git a/src/omop_emb/model_registry/model_registry_types.py b/src/omop_emb/model_registry/model_registry_types.py new file mode 100644 
index 0000000..4cc4acd --- /dev/null +++ b/src/omop_emb/model_registry/model_registry_types.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional, Mapping + +from omop_emb.config import ( + IndexType, + BackendType +) + +@dataclass(frozen=True) +class EmbeddingModelRecord: + """ + Canonical description of a registered embedding model. + + ``storage_identifier`` is intentionally backend-specific. For example: + - PostgreSQL backend: dynamic embedding table name + - FAISS backend: on-disk index path or logical collection name + """ + + model_name: str + dimensions: int + backend_type: BackendType + index_type: IndexType + storage_identifier: str + metadata: Mapping[str, object] = field(default_factory=dict) diff --git a/src/omop_emb/utils/embedding_utils.py b/src/omop_emb/utils/embedding_utils.py index 8cdb503..202d26a 100644 --- a/src/omop_emb/utils/embedding_utils.py +++ b/src/omop_emb/utils/embedding_utils.py @@ -7,24 +7,6 @@ from ..config import BackendType, IndexType, MetricType -@dataclass(frozen=True) -class EmbeddingModelRecord: - """ - Canonical description of a registered embedding model. - - ``storage_identifier`` is intentionally backend-specific. For example: - - PostgreSQL backend: dynamic embedding table name - - FAISS backend: on-disk index path or logical collection name - """ - - model_name: str - dimensions: int - backend_type: BackendType - index_type: IndexType - storage_identifier: Optional[str] = None - metadata: Mapping[str, object] = field(default_factory=dict) - - @dataclass(frozen=True) class EmbeddingConceptFilter: """ @@ -72,27 +54,6 @@ class NearestConceptMatch: is_active: bool -@dataclass(frozen=True) -class EmbeddingBackendCapabilities: - """ - Capability flags for a backend implementation. - - These are not used by the current code yet, but they make backend - differences explicit. 
For example, a FAISS backend might support nearest - neighbor search but require explicit refreshes after bulk writes. - """ - - stores_embeddings: bool = True - supports_incremental_upsert: bool = True - supports_nearest_neighbor_search: bool = True - supports_server_side_similarity: bool = True - supports_filter_by_concept_ids: bool = True - supports_filter_by_domain: bool = True - supports_filter_by_vocabulary: bool = True - supports_filter_by_standard_flag: bool = True - requires_explicit_index_refresh: bool = False - - def get_similarity_from_distance(distance_col, metric: MetricType): """ Helper to map various distance metrics to a 0.0 - 1.0 similarity score. diff --git a/tests/conftest.py b/tests/conftest.py index 8ec92d8..54c27f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,6 @@ import pytest import numpy as np import sqlalchemy as sa -from sqlalchemy.engine import make_url from sqlalchemy.orm import Session, sessionmaker @@ -23,7 +22,6 @@ from omop_emb.backends.faiss import FaissEmbeddingBackend from omop_emb.backends.pgvector import PGVectorEmbeddingBackend -from omop_emb.backends.registry import ensure_model_registry_schema, ModelRegistry from omop_emb.interface import EmbeddingInterface @@ -89,8 +87,6 @@ def _create_test_database() -> sa.URL: conn.execute(sa.text(f"DROP ROLE IF EXISTS {TEST_USERNAME}")) conn.execute(sa.text(f"CREATE USER {TEST_USERNAME} WITH PASSWORD '{TEST_PASSWORD}' SUPERUSER")) conn.execute(sa.text(f'CREATE DATABASE "{TEST_DB_NAME}" OWNER {TEST_USERNAME}')) - conn.execute(sa.text(f"DROP TABLE IF EXISTS {ModelRegistry.__tablename__} CASCADE;")) - finally: admin_engine.dispose() @@ -134,7 +130,6 @@ def pg_engine(): # Create all tables Base.metadata.create_all(engine) - ensure_model_registry_schema(engine) yield engine @@ -157,11 +152,8 @@ def session(pg_engine) -> Generator[Session, None, None]: # Clear tables before test with pg_engine.connect() as conn: with conn.begin(): - res = conn.execute(sa.text(f"SELECT 
table_name FROM {ModelRegistry.__tablename__}")) - for (table_name,) in res: - conn.execute(sa.text(f"DROP TABLE IF EXISTS {table_name} CASCADE")) - conn.execute(sa.text(f"TRUNCATE TABLE concept, {ModelRegistry.__tablename__} CASCADE")) - + conn.execute(sa.text(f"TRUNCATE TABLE concept CASCADE")) + Session = sessionmaker(bind=pg_engine, future=True) db_session = Session() @@ -193,24 +185,24 @@ def create_embeddings(concept_names: list[str] | str, batch_size: Optional[int] @pytest.fixture -def temp_faiss_dir(): +def temp_storage_dir(): """Temporary directory for FAISS indices.""" with tempfile.TemporaryDirectory() as tmpdir: yield Path(tmpdir) @pytest.fixture -def faiss_backend(session, temp_faiss_dir) -> FaissEmbeddingBackend: +def faiss_backend(session, temp_storage_dir) -> FaissEmbeddingBackend: """FAISS backend with model registry initialized.""" - backend = FaissEmbeddingBackend(base_dir=temp_faiss_dir) + backend = FaissEmbeddingBackend(storage_base_dir=temp_storage_dir) backend.initialise_store(session.bind) return backend @pytest.fixture -def pgvector_backend(session) -> PGVectorEmbeddingBackend: +def pgvector_backend(session, temp_storage_dir) -> PGVectorEmbeddingBackend: """PGVector backend with vector extension and model registry initialized.""" - backend = PGVectorEmbeddingBackend() + backend = PGVectorEmbeddingBackend(storage_base_dir=temp_storage_dir) backend.initialise_store(session.bind) return backend diff --git a/tests/shared_backend_tests.py b/tests/shared_backend_tests.py index c0fb70d..0a9dca4 100644 --- a/tests/shared_backend_tests.py +++ b/tests/shared_backend_tests.py @@ -12,39 +12,39 @@ class SharedBackendTests: """Base class containing backend-parity tests via a generic ``backend`` fixture.""" - def test_backend_registration(self, session, backend): + def test_backend_registration(self, session, backend, index_type: IndexType = IndexType.FLAT): """Test registering model with the backend.""" model = backend.register_model( 
engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) assert model.model_name == MODEL_NAME assert model.dimensions == EMBEDDING_DIM - def test_get_registered_model(self, session, backend): + def test_get_registered_model(self, session, backend, index_type: IndexType = IndexType.FLAT): """Test retrieving registered model.""" backend.register_model( engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) - retrieved = backend.get_registered_model(session, MODEL_NAME) + retrieved = backend.get_registered_model(model_name=MODEL_NAME, index_type=index_type) assert retrieved is not None assert retrieved.model_name == MODEL_NAME - def test_duplicate_registration_identical_success(self, session, backend): + def test_duplicate_registration_identical_success(self, session, backend, index_type: IndexType = IndexType.FLAT): """Test that idempotent model registration returns the existing record.""" params = { "engine": session.bind, "model_name": MODEL_NAME, "dimensions": EMBEDDING_DIM, - "index_type": IndexType.FLAT, + "index_type": index_type, "metadata": {"version": "1.0"}, } @@ -54,13 +54,13 @@ def test_duplicate_registration_identical_success(self, session, backend): assert model1.storage_identifier == model2.storage_identifier assert model2.metadata["version"] == "1.0" - def test_registration_metadata_conflict_raises_error(self, session, backend): + def test_registration_metadata_conflict_raises_error(self, session, backend, index_type: IndexType = IndexType.FLAT): """Test conflicting metadata on re-registration raises conflict error.""" params = { "engine": session.bind, "model_name": MODEL_NAME, "dimensions": EMBEDDING_DIM, - "index_type": IndexType.FLAT, + "index_type": index_type, "metadata": {"version": "1.0"}, } backend.register_model(**params) @@ -73,13 +73,13 @@ def test_registration_metadata_conflict_raises_error(self, 
session, backend): assert excinfo.value.conflict_field == "metadata" - def test_upsert_embeddings(self, session, backend, mock_llm_client): + def test_upsert_embeddings(self, session, backend, mock_llm_client, index_type: IndexType = IndexType.FLAT): """Test upserting embeddings.""" backend.register_model( engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) test_concept = CONCEPTS["Hypertension"] @@ -89,19 +89,20 @@ def test_upsert_embeddings(self, session, backend, mock_llm_client): backend.upsert_embeddings( session=session, model_name=MODEL_NAME, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) - assert backend.has_any_embeddings(session, MODEL_NAME) + assert backend.has_any_embeddings(session=session, model_name=MODEL_NAME, index_type=index_type) - def test_get_embeddings_by_ids(self, session, backend, mock_llm_client): + def test_get_embeddings_by_ids(self, session, backend, mock_llm_client, index_type: IndexType = IndexType.FLAT): """Test retrieving embeddings by concept IDs.""" backend.register_model( engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) test_concepts = [CONCEPTS["Hypertension"], CONCEPTS["Diabetes"]] @@ -111,6 +112,7 @@ def test_get_embeddings_by_ids(self, session, backend, mock_llm_client): backend.upsert_embeddings( session=session, model_name=MODEL_NAME, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) @@ -119,18 +121,19 @@ def test_get_embeddings_by_ids(self, session, backend, mock_llm_client): session=session, model_name=MODEL_NAME, concept_ids=concept_ids, + index_type=index_type, ) assert len(retrieved) == 2 assert set(retrieved.keys()) == set(concept_ids) - def test_nearest_neighbor_search(self, session, backend, mock_llm_client): + def test_nearest_neighbor_search(self, session, backend, mock_llm_client, index_type: IndexType = 
IndexType.FLAT): """Test exact top-1 nearest-neighbor retrieval with deterministic embeddings.""" backend.register_model( engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) test_concepts = list(CONCEPTS.values()) @@ -140,6 +143,7 @@ def test_nearest_neighbor_search(self, session, backend, mock_llm_client): backend.upsert_embeddings( session=session, model_name=MODEL_NAME, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) @@ -148,6 +152,7 @@ def test_nearest_neighbor_search(self, session, backend, mock_llm_client): results = backend.get_nearest_concepts( session=session, model_name=MODEL_NAME, + index_type=index_type, query_embeddings=query_embeddings, metric_type=MetricType.L2, k=1, @@ -157,13 +162,13 @@ def test_nearest_neighbor_search(self, session, backend, mock_llm_client): assert len(results[0]) == 1 assert results[0][0].concept_id == CONCEPTS["Hypertension"].concept_id - def test_nearest_neighbor_with_domain_filter(self, session, backend, mock_llm_client): + def test_nearest_neighbor_with_domain_filter(self, session, backend, mock_llm_client, index_type: IndexType = IndexType.FLAT): """Test nearest neighbor search with domain filter.""" backend.register_model( engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) test_concepts = list(CONCEPTS.values()) @@ -173,6 +178,7 @@ def test_nearest_neighbor_with_domain_filter(self, session, backend, mock_llm_cl backend.upsert_embeddings( session=session, model_name=MODEL_NAME, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) @@ -180,6 +186,7 @@ def test_nearest_neighbor_with_domain_filter(self, session, backend, mock_llm_cl results = backend.get_nearest_concepts( session=session, model_name=MODEL_NAME, + index_type=index_type, query_embeddings=TEST_CONCEPT_EMB, concept_filter=EmbeddingConceptFilter(domains=("Condition",)), 
metric_type=MetricType.L2, @@ -190,13 +197,13 @@ def test_nearest_neighbor_with_domain_filter(self, session, backend, mock_llm_cl assert len(results) == 1 assert set(m.concept_id for m in results[0]) == expected_ids - def test_nearest_neighbor_with_vocabulary_filter(self, session, backend, mock_llm_client): + def test_nearest_neighbor_with_vocabulary_filter(self, session, backend, mock_llm_client, index_type: IndexType = IndexType.FLAT): """Test nearest neighbor search with vocabulary filter.""" backend.register_model( engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) test_concepts = list(CONCEPTS.values()) @@ -206,6 +213,7 @@ def test_nearest_neighbor_with_vocabulary_filter(self, session, backend, mock_ll backend.upsert_embeddings( session=session, model_name=MODEL_NAME, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) @@ -213,6 +221,7 @@ def test_nearest_neighbor_with_vocabulary_filter(self, session, backend, mock_ll results = backend.get_nearest_concepts( session=session, model_name=MODEL_NAME, + index_type=index_type, query_embeddings=TEST_CONCEPT_EMB, concept_filter=EmbeddingConceptFilter(vocabularies=("RxNorm",)), metric_type=MetricType.L2, @@ -223,13 +232,13 @@ def test_nearest_neighbor_with_vocabulary_filter(self, session, backend, mock_ll assert len(results[0]) == 1 assert results[0][0].concept_id == CONCEPTS["Aspirin"].concept_id - def test_get_concepts_without_embedding(self, session, backend, mock_llm_client): + def test_get_concepts_without_embedding(self, session, backend, mock_llm_client, index_type: IndexType = IndexType.FLAT): """Test retrieving concepts without embeddings.""" backend.register_model( engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) test_concept = CONCEPTS["Hypertension"] @@ -239,6 +248,7 @@ def test_get_concepts_without_embedding(self, session, backend, 
mock_llm_client) backend.upsert_embeddings( session=session, model_name=MODEL_NAME, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) @@ -246,13 +256,14 @@ def test_get_concepts_without_embedding(self, session, backend, mock_llm_client) unembedded = backend.get_concepts_without_embedding( session=session, model_name=MODEL_NAME, + index_type=index_type, ) assert CONCEPTS["Aspirin"].concept_id in unembedded assert CONCEPTS["Diabetes"].concept_id in unembedded assert CONCEPTS["Hypertension"].concept_id not in unembedded - def test_l2_similarity_exact_values(self, session, backend, mock_llm_client): + def test_l2_similarity_exact_values(self, session, backend, mock_llm_client, index_type: IndexType = IndexType.FLAT): """Test L2 distance calculations yield expected similarity scores. With deterministic embeddings: @@ -264,8 +275,8 @@ def test_l2_similarity_exact_values(self, session, backend, mock_llm_client): backend.register_model( engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) test_concepts = list(CONCEPTS.values()) @@ -275,6 +286,7 @@ def test_l2_similarity_exact_values(self, session, backend, mock_llm_client): backend.upsert_embeddings( session=session, model_name=MODEL_NAME, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) @@ -282,6 +294,7 @@ def test_l2_similarity_exact_values(self, session, backend, mock_llm_client): results = backend.get_nearest_concepts( session=session, model_name=MODEL_NAME, + index_type=index_type, query_embeddings=TEST_CONCEPT_EMB, metric_type=MetricType.L2, k=10, @@ -301,7 +314,7 @@ def test_l2_similarity_exact_values(self, session, backend, mock_llm_client): rtol=1e-5, ) - def test_cosine_similarity_exact_values(self, session, backend, mock_llm_client): + def test_cosine_similarity_exact_values(self, session, backend, mock_llm_client, index_type: IndexType = IndexType.FLAT): """Test cosine distance 
calculations yield expected similarity scores. With deterministic embeddings (1D vectors): @@ -316,8 +329,8 @@ def test_cosine_similarity_exact_values(self, session, backend, mock_llm_client) backend.register_model( engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) test_concepts = list(CONCEPTS.values()) @@ -327,6 +340,7 @@ def test_cosine_similarity_exact_values(self, session, backend, mock_llm_client) backend.upsert_embeddings( session=session, model_name=MODEL_NAME, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) @@ -334,6 +348,7 @@ def test_cosine_similarity_exact_values(self, session, backend, mock_llm_client) results = backend.get_nearest_concepts( session=session, model_name=MODEL_NAME, + index_type=index_type, query_embeddings=TEST_CONCEPT_EMB, metric_type=MetricType.COSINE, k=10, diff --git a/tests/test_config.py b/tests/test_config.py index 8e247ea..eadcb91 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -17,7 +17,7 @@ def test_normalize_backend_name(self): def test_get_faiss_backend(self, tmp_path): """Test getting FAISS backend.""" - backend = get_embedding_backend("faiss", faiss_base_dir=str(tmp_path)) + backend = get_embedding_backend("faiss", storage_base_dir=str(tmp_path)) assert backend.backend_type == BackendType.FAISS def test_faiss_supports_flat_index(self): diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py index d5af158..b4bfac4 100644 --- a/tests/test_fixtures.py +++ b/tests/test_fixtures.py @@ -63,10 +63,6 @@ def test_mock_llm_client_generates_embeddings(self, mock_llm_client): # Should have correct shape assert emb1.shape == (1, mock_llm_client.embedding_dim) - def test_faiss_backend_initializes(self, faiss_backend, session): - """Verify FAISS backend initializes without error.""" - assert faiss_backend is not None - assert faiss_backend.base_dir.exists() def test_embedding_interface_initializes(self, 
embedding_interface, session): """Verify EmbeddingInterface initializes with mocks.""" diff --git a/tests/test_interface.py b/tests/test_interface.py index c38916a..978a651 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -17,9 +17,8 @@ class TestInterface: def test_register_model(self, session, embedding_interface: EmbeddingInterface): """Test registering an embedding model.""" - model = embedding_interface.ensure_model_registered( + model = embedding_interface.register_model( engine=session.bind, - session=session, model_name=MODEL_NAME, dimensions=EMBEDDING_DIM, index_type=IndexType.FLAT, @@ -28,39 +27,36 @@ def test_register_model(self, session, embedding_interface: EmbeddingInterface): assert model.model_name == MODEL_NAME assert model.dimensions == EMBEDDING_DIM - def test_model_registration_idempotent(self, session, embedding_interface: EmbeddingInterface): + def test_model_registration_idempotent(self, session, embedding_interface: EmbeddingInterface, index_type: IndexType = IndexType.FLAT): """Test registering same model twice returns existing.""" - m1 = embedding_interface.ensure_model_registered( + m1 = embedding_interface.register_model( engine=session.bind, - session=session, model_name=MODEL_NAME, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, + index_type=index_type, ) - m2 = embedding_interface.ensure_model_registered( + m2 = embedding_interface.register_model( engine=session.bind, - session=session, model_name=MODEL_NAME, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, + index_type=index_type, ) assert m1.storage_identifier == m2.storage_identifier - def test_is_model_registered(self, session, embedding_interface: EmbeddingInterface): + def test_is_model_registered(self, session, embedding_interface: EmbeddingInterface, index_type: IndexType = IndexType.FLAT): """Test checking model registration status.""" - assert not embedding_interface.is_model_registered(session, MODEL_NAME) + assert not 
embedding_interface.is_model_registered(model_name=MODEL_NAME, index_type=index_type) - embedding_interface.ensure_model_registered( + embedding_interface.register_model( engine=session.bind, - session=session, model_name=MODEL_NAME, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, + index_type=index_type, ) - assert embedding_interface.is_model_registered(session, MODEL_NAME) + assert embedding_interface.is_model_registered(model_name=MODEL_NAME, index_type=index_type) def test_embed_texts(self, embedding_interface: EmbeddingInterface): """Test embedding generation.""" @@ -76,14 +72,13 @@ def test_embed_multiple_texts(self, embedding_interface: EmbeddingInterface): assert embeddings.shape == (len(texts), EMBEDDING_DIM) - def test_embed_and_upsert(self, session, embedding_interface: EmbeddingInterface): + def test_embed_and_upsert(self, session, embedding_interface: EmbeddingInterface, index_type: IndexType = IndexType.FLAT): """Test embedding and upserting concepts.""" - embedding_interface.ensure_model_registered( + embedding_interface.register_model( engine=session.bind, - session=session, model_name=MODEL_NAME, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, + index_type=index_type, ) test_concepts = [CONCEPTS["Hypertension"], CONCEPTS["Diabetes"]] @@ -95,19 +90,19 @@ def test_embed_and_upsert(self, session, embedding_interface: EmbeddingInterface model_name=MODEL_NAME, concept_ids=concept_ids, concept_texts=concept_texts, + index_type=index_type, ) assert embeddings.shape == (2, EMBEDDING_DIM) - assert embedding_interface.has_any_embeddings(session, MODEL_NAME) + assert embedding_interface.has_any_embeddings(session, embedding_model_name=MODEL_NAME, index_type=index_type) - def test_get_embeddings_by_concept_ids(self, session, embedding_interface: EmbeddingInterface): + def test_get_embeddings_by_concept_ids(self, session, embedding_interface: EmbeddingInterface, index_type: IndexType = IndexType.FLAT): """Test retrieving stored embeddings.""" - 
embedding_interface.ensure_model_registered( + embedding_interface.register_model( engine=session.bind, - session=session, model_name=MODEL_NAME, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, + index_type=index_type, ) test_concepts = [CONCEPTS["Hypertension"], CONCEPTS["Diabetes"]] @@ -117,6 +112,7 @@ def test_get_embeddings_by_concept_ids(self, session, embedding_interface: Embed embeddings = embedding_interface.embed_and_upsert_concepts( session=session, model_name=MODEL_NAME, + index_type=index_type, concept_ids=concept_ids, concept_texts=concept_texts, ) @@ -125,13 +121,14 @@ def test_get_embeddings_by_concept_ids(self, session, embedding_interface: Embed session=session, embedding_model_name=MODEL_NAME, concept_ids=concept_ids, + index_type=index_type, ) assert len(retrieved) == 2 assert 1 in retrieved and 2 in retrieved @pytest.mark.parametrize("backend_name", ["faiss", "pgvector"]) - def test_search_return_structure_for_backends(self, session, mock_llm_client, backend_name): + def test_search_return_structure_for_backends(self, session, mock_llm_client, backend_name, index_type: IndexType = IndexType.FLAT): """Search returns tuple[dict[int, float], ...] 
for both backend modes.""" mock_backend = Mock() mock_backend.get_nearest_concepts.return_value = ( @@ -162,6 +159,7 @@ def test_search_return_structure_for_backends(self, session, mock_llm_client, ba result = interface.get_nearest_concepts( session=session, model_name=MODEL_NAME, + index_type=index_type, query_embedding=query_embedding, metric_type=MetricType.COSINE, k=2, From 82a72d1b1117af4e6bd91ca5bc0aff596df9b35a Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Wed, 8 Apr 2026 04:57:27 +0000 Subject: [PATCH 06/19] Have model_record as private attrib and populate from decorator --- src/omop_emb/backends/base.py | 28 ++++---- src/omop_emb/backends/faiss/faiss_backend.py | 69 +++++++------------ .../backends/pgvector/pgvector_backend.py | 18 ++--- .../model_registry/model_registry_manager.py | 4 +- 4 files changed, 50 insertions(+), 69 deletions(-) diff --git a/src/omop_emb/backends/base.py b/src/omop_emb/backends/base.py index 2061b64..a9b5e1d 100644 --- a/src/omop_emb/backends/base.py +++ b/src/omop_emb/backends/base.py @@ -31,9 +31,10 @@ def wrapper( if record is None: raise ValueError( - f"Embedding model '{model_name}' is not registered in the FAISS backend." + f"Embedding model '{model_name}' with index_type='{index_type.value}' is not registered in backend '{self.backend_type.value}'." 
) - return func(self, model_name, index_type, record, *args, **kwargs) + kwargs["_model_record"] = record + return func(self, model_name, index_type, *args, **kwargs) return wrapper @@ -141,13 +142,7 @@ def initialise_store(self, engine: Engine) -> None: return for model_record in registered_models: - self.register_model( - engine=engine, - model_name=model_record.model_name, - dimensions=model_record.dimensions, - index_type=model_record.index_type, - metadata=model_record.metadata, - ) + self._cache_model_record(engine=engine, model_record=model_record) def pre_initialise_store(self, engine: Engine) -> None: """Hook for any setup steps that need to happen before the model registry schema is created. For example, this is used by the PGVector backend to create the vector extension before the registry tables are created.""" @@ -243,7 +238,7 @@ def register_model( dimensions: int, *, index_type: IndexType, - metadata: Mapping[str, object] = {}, + metadata: Optional[Mapping[str, object]] = None, ) -> EmbeddingModelRecord: """ Shared template method for model registration. @@ -253,7 +248,7 @@ def register_model( dimensions=dimensions, backend_type=self.backend_type, index_type=index_type, - metadata=metadata, + metadata=metadata or {}, ) # Local caching self._cache_model_record(engine=engine, model_record=model_record) @@ -265,11 +260,11 @@ def upsert_embeddings( self, model_name: str, index_type: IndexType, - model_record: EmbeddingModelRecord, - * + *, session: Session, concept_ids: Sequence[int], embeddings: ndarray, + _model_record: EmbeddingModelRecord, ) -> None: """ Insert or update vector embeddings for a collection of OMOP concept IDs. 
@@ -316,9 +311,10 @@ def get_embeddings_by_concept_ids( self, model_name: str, index_type: IndexType, - model_record: EmbeddingModelRecord, + *, session: Session, concept_ids: Sequence[int], + _model_record: EmbeddingModelRecord, ) -> Mapping[int, Sequence[float]]: """Fetch stored embedding vectors keyed by concept ID.""" @@ -329,13 +325,13 @@ def get_nearest_concepts( self, model_name: str, index_type: IndexType, - model_record: EmbeddingModelRecord, - * + *, session: Session, query_embedding: ndarray, metric_type: MetricType, concept_filter: Optional[EmbeddingConceptFilter] = None, k: int = 10, + _model_record: EmbeddingModelRecord, ) -> Tuple[Tuple[NearestConceptMatch, ...], ...]: """ Return nearest stored concepts for the query embedding. diff --git a/src/omop_emb/backends/faiss/faiss_backend.py b/src/omop_emb/backends/faiss/faiss_backend.py index 6ce13fc..ca756bc 100644 --- a/src/omop_emb/backends/faiss/faiss_backend.py +++ b/src/omop_emb/backends/faiss/faiss_backend.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os from pathlib import Path from typing import Mapping, Optional, Sequence, Type, Dict, Tuple @@ -16,7 +15,7 @@ add_concept_ids_to_faiss_registry, q_concept_ids_with_embeddings ) -from omop_emb.config import BackendType, IndexType, MetricType, ENV_OMOP_EMB_FAISS_INDEX_DIR +from omop_emb.config import BackendType, IndexType, MetricType from .storage_manager import EmbeddingStorageManager from ..base import EmbeddingBackend, require_registered_model from omop_emb.utils.embedding_utils import ( @@ -74,30 +73,24 @@ def get_safe_model_dir(self, model_name: str) -> Path: def get_storage_manager( self, - model_name: str, - dimensions: int, - index_type: IndexType, + model_record: EmbeddingModelRecord, ) -> EmbeddingStorageManager: self.register_storage_manager( - model_name=model_name, - dimensions=dimensions, - index_type=index_type, + model_record=model_record, ) - return self.embedding_storage_managers[model_name] + return 
self.embedding_storage_managers[model_record.model_name] def register_storage_manager( self, - model_name: str, - dimensions: int, - index_type: IndexType, + model_record: EmbeddingModelRecord, ) -> None: """Registers a storage manager for the given model if not already registered. This ensures that the necessary on-disk structures are in place for the model's embeddings and index.""" - if model_name not in self.embedding_storage_managers: - logger.info(f"Registering new storage manager for model '{model_name}' with dimensions={dimensions}, index_type={index_type}") - self.embedding_storage_managers[model_name] = EmbeddingStorageManager( - file_dir=self.get_safe_model_dir(model_name), - dimensions=dimensions, + if model_record.model_name not in self.embedding_storage_managers: + logger.info(f"Registering new storage manager for model '{model_record.model_name}' with dimensions={model_record.dimensions}, index_type={model_record.index_type}") + self.embedding_storage_managers[model_record.model_name] = EmbeddingStorageManager( + file_dir=self.get_safe_model_dir(model_record.model_name), + dimensions=model_record.dimensions, backend_type=self.backend_type, ) @@ -130,10 +123,10 @@ def upsert_embeddings( self, model_name: str, index_type: IndexType, - model_record: EmbeddingModelRecord, session: Session, concept_ids: Sequence[int], embeddings: ndarray, + _model_record: EmbeddingModelRecord, metric_type: Optional[MetricType] = None, ) -> None: """ @@ -183,7 +176,7 @@ def upsert_embeddings( self.validate_embeddings_and_concept_ids( concept_ids=concept_id_tuple, embeddings=embeddings, - dimensions=model_record.dimensions, + dimensions=_model_record.dimensions, ) # Try to add first before we damage the on-disk stuff @@ -202,15 +195,11 @@ def upsert_embeddings( f"Failed to add concept IDs to FAISS registry for model '{model_name}'. This may be due to a mismatch between the provided concept_ids and the existing entries in the database, or a database constraint violation. 
Original error: {str(e)}" ) from e - storage_manager = self.get_storage_manager( - model_name=model_name, - dimensions=model_record.dimensions, - index_type=model_record.index_type, - ) + storage_manager = self.get_storage_manager(_model_record) storage_manager.append( concept_ids=np.array(concept_id_tuple, dtype=np.int64), embeddings=embeddings, - index_type=model_record.index_type, + index_type=_model_record.index_type, metric_type=metric_type ) @@ -219,25 +208,21 @@ def get_nearest_concepts( self, model_name: str, index_type: IndexType, - model_record: EmbeddingModelRecord, session: Session, query_embeddings: np.ndarray, metric_type: MetricType, *, concept_filter: Optional[EmbeddingConceptFilter] = None, k: int = 10, + _model_record: EmbeddingModelRecord, ) -> Tuple[Tuple[NearestConceptMatch, ...], ...]: - storage_manager = self.get_storage_manager( - model_name=model_name, - dimensions=model_record.dimensions, - index_type=model_record.index_type, - ) + storage_manager = self.get_storage_manager(_model_record) if storage_manager.get_count() == 0: return () - self.validate_embeddings(embeddings=query_embeddings, dimensions=model_record.dimensions) + self.validate_embeddings(embeddings=query_embeddings, dimensions=_model_record.dimensions) q_permitted_concept_ids = q_concept_ids_with_embeddings( embedding_table=self.get_embedding_table(model_name=model_name, index_type=index_type), concept_filter=concept_filter, @@ -257,7 +242,7 @@ def get_nearest_concepts( distances, concept_ids = storage_manager.search( query_vector=query_embeddings, metric_type=metric_type, - index_type=model_record.index_type, + index_type=_model_record.index_type, k=k, subset_concept_ids=permitted_concept_ids ) @@ -284,12 +269,12 @@ def get_nearest_concepts( @require_registered_model def get_embeddings_by_concept_ids( - self, - model_name: str, - index_type: IndexType, - model_record: EmbeddingModelRecord, - session: Session, - concept_ids: Sequence[int] + self, + model_name: str, + 
index_type: IndexType, + session: Session, + concept_ids: Sequence[int], + _model_record: EmbeddingModelRecord ) -> Mapping[int, Sequence[float]]: concept_id_tuple = tuple(concept_ids) if not concept_id_tuple: @@ -297,10 +282,6 @@ def get_embeddings_by_concept_ids( concept_ids_np = np.array(concept_id_tuple, dtype=np.int64) - storage_manager = self.get_storage_manager( - model_name=model_name, - dimensions=model_record.dimensions, - index_type=model_record.index_type, - ) + storage_manager = self.get_storage_manager(_model_record) return storage_manager.get_embeddings_by_concept_ids(concept_ids=concept_ids_np) diff --git a/src/omop_emb/backends/pgvector/pgvector_backend.py b/src/omop_emb/backends/pgvector/pgvector_backend.py index edff569..c76271d 100644 --- a/src/omop_emb/backends/pgvector/pgvector_backend.py +++ b/src/omop_emb/backends/pgvector/pgvector_backend.py @@ -49,17 +49,18 @@ def upsert_embeddings( self, model_name: str, index_type: IndexType, - model_record: EmbeddingModelRecord, + *, session: Session, concept_ids: Sequence[int], embeddings: ndarray, + _model_record: EmbeddingModelRecord, ) -> None: concept_id_tuple = tuple(concept_ids) self.validate_embeddings_and_concept_ids( concept_ids=concept_id_tuple, embeddings=embeddings, - dimensions=model_record.dimensions, + dimensions=_model_record.dimensions, ) table = self.get_embedding_table( @@ -85,10 +86,12 @@ def get_embeddings_by_concept_ids( self, model_name: str, index_type: IndexType, - model_record: EmbeddingModelRecord, + *, session: Session, concept_ids: Sequence[int], + _model_record: EmbeddingModelRecord, ) -> Mapping[int, Sequence[float]]: + concept_id_tuple = tuple(concept_ids) if not concept_id_tuple: return {} @@ -119,20 +122,19 @@ def get_nearest_concepts( self, model_name: str, index_type: IndexType, - model_record: EmbeddingModelRecord, + *, session: Session, query_embeddings: ndarray, - *, metric_type: MetricType, concept_filter: Optional[EmbeddingConceptFilter] = None, k: int = 10, - ) 
-> Tuple[Tuple[NearestConceptMatch, ...], ...]: - + _model_record: EmbeddingModelRecord, + ) -> Tuple[Tuple[NearestConceptMatch, ...], ...]: embedding_table = self.get_embedding_table( model_name=model_name, index_type=index_type, ) - self.validate_embeddings(embeddings=query_embeddings, dimensions=model_record.dimensions) + self.validate_embeddings(embeddings=query_embeddings, dimensions=_model_record.dimensions) query_list = query_embeddings.tolist() query = q_embedding_nearest_concepts( diff --git a/src/omop_emb/model_registry/model_registry_manager.py b/src/omop_emb/model_registry/model_registry_manager.py index df9c851..18bd10b 100644 --- a/src/omop_emb/model_registry/model_registry_manager.py +++ b/src/omop_emb/model_registry/model_registry_manager.py @@ -66,7 +66,7 @@ def register_model( *, backend_type: BackendType, index_type: IndexType, - metadata: Mapping[str, object] = {}, + metadata: Optional[Mapping[str, object]] = None, ) -> EmbeddingModelRecord: """ Shared template method for model registration. 
@@ -112,6 +112,8 @@ def register_model( ) return self._registry_entry_to_model_record(existing_row) + metadata = metadata or {} + safe_name = self.safe_model_name(model_name) storage_name = self.storage_name( safe_model_name=safe_name, From 9b684f0398bf91ef62309a6fea90858d6bbd379b Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Wed, 8 Apr 2026 05:13:07 +0000 Subject: [PATCH 07/19] Unified docstring, adapted documentation --- README.md | 22 +++++- docs/index.md | 20 +++-- docs/usage/backend-selection.md | 16 ++-- docs/usage/cli.md | 17 +++- docs/usage/installation.md | 10 ++- mkdocs.yml | 1 + src/omop_emb/backends/base.py | 77 +++++++++++++------ src/omop_emb/backends/faiss/faiss_backend.py | 24 ++---- .../backends/pgvector/pgvector_backend.py | 56 ++++++++++++++ src/omop_emb/config.py | 1 - 10 files changed, 178 insertions(+), 66 deletions(-) diff --git a/README.md b/README.md index 949bbcd..42a7d2a 100644 --- a/README.md +++ b/README.md @@ -1,27 +1,41 @@ # omop-emb Embedding layer for OMOP CDM. +`omop-emb` now separates model metadata from embedding storage: + +- model metadata is stored locally in SQLite (`metadata.db`) +- embedding vectors are stored by the selected backend (`pgvector` or `faiss`) +- OMOP concept metadata remains in the OMOP CDM database + ## Installation `omop-emb` now exposes backend-specific optional dependencies so installation can match the embedding backend you actually intend to use. ```bash -pip install "omop-emb[postgres]" +pip install "omop-emb[pgvector]" pip install "omop-emb[faiss]" pip install "omop-emb[all]" ``` Notes: -- `postgres` installs the PostgreSQL/pgvector dependencies. +- `pgvector` installs the PostgreSQL/pgvector dependencies. - `faiss` installs the FAISS-based backend dependencies. This currently only includes CPU support - `all` installs both backend stacks for development or mixed environments. - A plain `pip install omop-emb` installs the shared core package only. 
-- PostgreSQL-specific embedding dependencies are now optional, but `omop-emb` - still requires some database backend for OMOP access and model registration. +- PostgreSQL-specific embedding dependencies are optional, but `omop-emb` + still requires OMOP CDM database access. - Non-PostgreSQL database backends have not yet been tested. +## Runtime Configuration + +Common environment variables: + +- `OMOP_EMB_BACKEND`: backend name (`pgvector` or `faiss`) used by the backend factory. +- `OMOP_EMB_BASE_STORAGE_DIR`: local base directory for `omop-emb` artifacts, including local metadata (`metadata.db`) and FAISS files. +- `OMOP_DATABASE_URL`: SQLAlchemy URL for the OMOP CDM database. + Extended documentation can be found [here](https://AustralianCancerDataNetwork.github.io/omop-emb). # Project Roadmap diff --git a/docs/index.md b/docs/index.md index c62c5af..fc36fbd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -5,11 +5,12 @@ The package currently supports: - dynamic embedding model registration - - multiple embedding models can be stored in the respective backend + - model metadata is stored locally in SQLite (`metadata.db`) + - multiple embedding models can be tracked per backend and index type - embedding and lookup for OMOP concepts -- supports various backends with a PostgreSQL linker +- supports various storage backends - [pgvector](https://github.com/pgvector/pgvector): storage in the original OMOP database - - [FAISS](https://github.com/facebookresearch/faiss): efficient storage on disk for low-RAM applications + - [FAISS](https://github.com/facebookresearch/faiss): on-disk vector storage and index files - Extension to [`omop-alchemy`](https://AustralianCancerDataNetwork.github.io/OMOP_Alchemy/) to support new tables - CLI scripts to add embeddings to an already existing OMOP CDM @@ -18,7 +19,7 @@ The package currently supports: Install the backend you actually want to use: ```bash -pip install "omop-emb[postgres]" +pip install "omop-emb[pgvector]" pip 
install "omop-emb[faiss]" pip install "omop-emb[all]" ``` @@ -28,12 +29,19 @@ A plain `pip install omop-emb` installs only the shared core package. At runtime, backend choice should also be explicit. The intended direction is: - install-time choice via extras -- runtime choice via config such as `OMOP_EMB_BACKEND=postgres` or `OMOP_EMB_BACKEND=faiss` or passing it as an argument to the respective interface (e.g. see [CLI reference](usage/cli.md)) +- runtime choice via config such as `OMOP_EMB_BACKEND=pgvector` or `OMOP_EMB_BACKEND=faiss` or passing it as an argument to the respective interface (e.g. see [CLI reference](usage/cli.md)) + +Recommended runtime environment variables: + +- `OMOP_EMB_BACKEND` (`pgvector` or `faiss`) +- `OMOP_EMB_BASE_STORAGE_DIR` (base directory for local metadata and FAISS artifacts) +- `OMOP_DATABASE_URL` (OMOP CDM database URL) !!! info "Important caveats" - - `omop-emb` depends on an OMOP PostgreSQL database for storage of embeddings (pgvector) or to keep track of already embedded concepts. + - `omop-emb` depends on OMOP CDM database access for concept metadata and filtering. + - Current operational and test coverage is PostgreSQL-focused. ## Documentation overview diff --git a/docs/usage/backend-selection.md b/docs/usage/backend-selection.md index 29aa9aa..ca5daff 100644 --- a/docs/usage/backend-selection.md +++ b/docs/usage/backend-selection.md @@ -11,7 +11,8 @@ The current backend factory recognizes: - `pgvector`: The [pgvector](https://github.com/pgvector/pgvector) extension to a standard postgres database to store embeddings directly in the database. - `faiss`: The [FAISS](https://github.com/facebookresearch/faiss) storage solution for on-disk storage. -The default backend name is currently `postgres`. +There is no implicit default backend name. You must pass one explicitly or set +`OMOP_EMB_BACKEND`. 
## Runtime selection @@ -23,8 +24,9 @@ The intended pattern is: Examples: ```bash -export OMOP_EMB_BACKEND=postgres +export OMOP_EMB_BACKEND=pgvector export OMOP_EMB_BACKEND=faiss +export OMOP_EMB_BASE_STORAGE_DIR=$HOME/.omop_emb ``` You can also pass the backend name directly in Python. @@ -36,7 +38,7 @@ The backend factory lives in `omop_emb.backends`: ```python from omop_emb.backends import get_embedding_backend -backend = get_embedding_backend("postgres") +backend = get_embedding_backend("pgvector") backend = get_embedding_backend("faiss") ``` @@ -78,10 +80,10 @@ At the moment: - the backend abstraction and backend factory exist - PostgreSQL and FAISS backend classes exist - the production CLI path still targets the PostgreSQL embedding workflow -- PostgreSQL-specific embedding dependencies are optional, but a database - backend is still required for OMOP access and model registration -- model registration is intended to remain shared and database-backed even when - FAISS is used for vector storage and retrieval +- PostgreSQL-specific embedding dependencies are optional, but OMOP database + access is still required for concept metadata +- model registration metadata is stored locally in SQLite (`metadata.db`) under + `OMOP_EMB_BASE_STORAGE_DIR` - database backends other than PostgreSQL have not yet been tested So this page documents the selection model and Python interface shape now, even diff --git a/docs/usage/cli.md b/docs/usage/cli.md index 2697cbc..1e5c71f 100644 --- a/docs/usage/cli.md +++ b/docs/usage/cli.md @@ -11,14 +11,17 @@ At present, the production CLI path is PostgreSQL-oriented and stores embeddings ## Prerequisites -- **Installation**: install the PostgreSQL backend dependencies: +- **Installation**: install the backend dependencies you plan to use: ```bash - pip install "omop-emb[postgres]" + pip install "omop-emb[pgvector]" + # or + pip install "omop-emb[faiss]" ``` - **Database**: Postgres implementation of OMOP CDM. 
See [`omop-graph` documentation](reference-missing) for information how to setup. -- **Environment**: `OMOP_DATABASE_URL` must be exported or existing in the .env file (e.g., `postgresql://user:pass@localhost:5432/omop`). +- **Environment**: `OMOP_DATABASE_URL` must be exported or present in `.env` (e.g., `postgresql://user:pass@localhost:5432/omop`). +- **Backend config**: set `OMOP_EMB_BACKEND` (`pgvector` or `faiss`) and optionally `OMOP_EMB_BASE_STORAGE_DIR`. - **Connectivity**: Access to an OpenAI-compatible embeddings endpoint. *Currently only Ollama supported*. !!! note "Backend Scope" @@ -48,8 +51,14 @@ where `[OPTIONS]` are optional arguments that can be specified as described belo | **`--index-type`** | | `IndexType` | `FLAT` | The storage index for the embeddings for retrieval. Currently supported: `FLAT`. | | **`--batch-size`** | `-b` | `Integer` | `100` | Number of concepts to process in each chunk. | | **`--model`** | `-m` | `String` | `text-embedding-3-small` | Name of the embedding model to use for generating vectors. | -| **`--backend`** | | `Literal['pgvector', 'faiss']` | `None` | Embedding backend to use (can be replaced by `OMOP_EMB_BACKEND` env var). Requires the respective backend installed using `pip install omop-emb[pgvector or faiss]` | +| **`--backend`** | | `Literal['pgvector', 'faiss']` | `None` | Embedding backend to use (can be replaced by `OMOP_EMB_BACKEND`). Requires the corresponding optional dependency. | | **`--faiss-base-dir`** | | `String` | `None` | Optional base directory for FAISS backend storage. | | **`--standard-only`** | | `Boolean` | `False` | If set, only generate embeddings for OMOP standard concepts (`standard_concept = 'S'`). | | **`--vocabulary`** | | `List[String]` | `None` | Filter to embed concepts only from specific OMOP vocabularies. | | **`--num-embeddings`** | `-n` | `Integer` | `None` | Limit the number of concepts processed (useful for testing). 
| + +## Environment Variables + +- `OMOP_DATABASE_URL`: OMOP CDM database connection string. +- `OMOP_EMB_BACKEND`: backend selector used when `--backend` is not provided. +- `OMOP_EMB_BASE_STORAGE_DIR`: local storage root for metadata and file-based artifacts. diff --git a/docs/usage/installation.md b/docs/usage/installation.md index 80cd1cd..07d7186 100644 --- a/docs/usage/installation.md +++ b/docs/usage/installation.md @@ -18,7 +18,7 @@ requires database-backed OMOP access and model registration. ## PostgreSQL backend ```bash -pip install "omop-emb[postgres]" +pip install "omop-emb[pgvector]" ``` Use this when you want the current pgvector/PostgreSQL-backed embedding store @@ -33,7 +33,7 @@ pip install "omop-emb[faiss]" Use this when you want the FAISS backend dependencies available. Even in this case, a database backend is still required for OMOP concept -metadata and model registration. +metadata access. ## Everything @@ -52,12 +52,16 @@ install-time. Examples: ```bash -export OMOP_EMB_BACKEND=postgres +export OMOP_EMB_BACKEND=pgvector export OMOP_EMB_BACKEND=faiss +export OMOP_EMB_BASE_STORAGE_DIR=$HOME/.omop_emb ``` That avoids silent fallback between backend implementations. +`OMOP_EMB_BASE_STORAGE_DIR` controls where `omop-emb` stores local metadata +(`metadata.db`) and file-based backend artifacts (such as FAISS files). 
+ ## Current database support caveat PostgreSQL-specific embedding dependencies are now optional, but the broader diff --git a/mkdocs.yml b/mkdocs.yml index 25cd776..ddfdb98 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -31,6 +31,7 @@ nav: - Home: index.md - Getting Started: - Installation: usage/installation.md + - Backend Selection: usage/backend-selection.md - "CLI Reference": usage/cli.md plugins: diff --git a/src/omop_emb/backends/base.py b/src/omop_emb/backends/base.py index a9b5e1d..d8b418d 100644 --- a/src/omop_emb/backends/base.py +++ b/src/omop_emb/backends/base.py @@ -126,12 +126,26 @@ def _cache_model_record( @abstractmethod def _create_storage_table(self, engine: Engine, model_record: EmbeddingModelRecord) -> Type[T]: - """Backend-specific logic to create the dynamic SQLAlchemy class.""" + """ + Backend-specific logic to create the dynamic SQLAlchemy class. + + Parameters + ---------- + engine : Engine + SQLAlchemy engine for OMOP CDM database storage. + model_record : EmbeddingModelRecord + Registered model metadata used to build backend storage. + """ def initialise_store(self, engine: Engine) -> None: """ Initialise the model registry and populate the embedding table cache for existing models. Can extend it to include any backend-specific setup steps (e.g., creating extensions, staging directories) as needed. + + Parameters + ---------- + engine : Engine + SQLAlchemy engine for OMOP CDM database storage. """ self.pre_initialise_store(engine) @@ -145,7 +159,14 @@ def initialise_store(self, engine: Engine) -> None: self._cache_model_record(engine=engine, model_record=model_record) def pre_initialise_store(self, engine: Engine) -> None: - """Hook for any setup steps that need to happen before the model registry schema is created. For example, this is used by the PGVector backend to create the vector extension before the registry tables are created.""" + """ + Hook for setup steps that run before store initialization. 
+ + Parameters + ---------- + engine : Engine + SQLAlchemy engine for OMOP CDM database storage. + """ pass def get_registered_model( @@ -242,6 +263,19 @@ def register_model( ) -> EmbeddingModelRecord: """ Shared template method for model registration. + + Parameters + ---------- + engine : Engine + SQLAlchemy engine for OMOP CDM database storage. + model_name : str + Registered name of the embedding model. + dimensions : int + Embedding dimensionality $D$ for this model. + index_type : IndexType + Storage index type used for this model's embeddings. + metadata : Optional[Mapping[str, object]] + Optional metadata persisted with the model registration. """ model_record = self.embedding_model_registry.register_model( model_name=model_name, @@ -277,22 +311,17 @@ def upsert_embeddings( Parameters ---------- model_name : str - The unique identifier or name of the embedding model (e.g., - 'text-embedding-3-small'). + Registered name of the embedding model. index_type : IndexType - The type of vector index used to store the embeddings. - model_record : EmbeddingModelRecord - A record object containing metadata, dimensions, and configuration - specific to the embedding model being processed. + Storage index type used for this model's embeddings. session : sqlalchemy.orm.Session - The active database session used for transactional persistence and - model metadata updates. + SQLAlchemy session bound to the OMOP CDM database. concept_ids : Sequence[int] - A sequence of OMOP standard concept IDs corresponding to the - ordered rows in the embeddings array. + Concept IDs aligned with the rows of ``embeddings``. embeddings : numpy.ndarray - A 2D array of shape (n_concepts, n_dimensions) containing the - generated vector representations. + Embedding matrix of shape ``(n_concepts, D)``. + _model_record : EmbeddingModelRecord + Internal registered-model record injected by ``@require_registered_model``. 
Returns ------- @@ -339,23 +368,21 @@ def get_nearest_concepts( Parameters ---------- model_name : str - Name of the embedding model used to create the embeddings. + Registered name of the embedding model. index_type : IndexType - The type of vector index used to store the embeddings. - model_record : EmbeddingModelRecord - A record object containing metadata, dimensions, and configuration - specific to the embedding model being queried. Obtained through the decorator's requirement for a registered model. + Storage index type used for this model's embeddings. session : Session - SQLAlchemy session for any required relational access. + SQLAlchemy session bound to the OMOP CDM database. query_embedding : ndarray - The embedding vector to search with. Expected shape is (q, dimension) - where q is the number of query vectors and dimension is the size of the embedding space for the model. + Query embedding matrix of shape ``(Q, D)``. metric_type : MetricType - The similarity or distance metric to use for nearest neighbor search. This should be compatible with the index type used by the model. + Similarity or distance metric for nearest-neighbor search. concept_filter : Optional[EmbeddingConceptFilter], optional - Optional constraints to apply when retrieving nearest concepts. This allows filtering by concept IDs, domains, vocabularies, or standard flags. By default, no additional filtering is applied. + Optional filter restricting which OMOP concepts are considered. k : int, optional - K nearest neighbors to return for each query vector. Default is 10. + Number of nearest matches to return per query. + _model_record : EmbeddingModelRecord + Internal registered-model record injected by ``@require_registered_model``. 
Returns ------- diff --git a/src/omop_emb/backends/faiss/faiss_backend.py b/src/omop_emb/backends/faiss/faiss_backend.py index ca756bc..e197255 100644 --- a/src/omop_emb/backends/faiss/faiss_backend.py +++ b/src/omop_emb/backends/faiss/faiss_backend.py @@ -139,27 +139,19 @@ def upsert_embeddings( Parameters ---------- model_name : str - The unique identifier or name of the embedding model (e.g., - 'text-embedding-3-small'). + Registered name of the embedding model. index_type : IndexType - The type of vector index used to store the embeddings. - model_record : EmbeddingModelRecord - A record object containing metadata, dimensions, and configuration - specific to the embedding model being processed. + Storage index type used for this model's embeddings. session : sqlalchemy.orm.Session - The active database session used for transactional persistence and - model metadata updates. + SQLAlchemy session bound to the OMOP CDM database. concept_ids : Sequence[int] - A sequence of OMOP standard concept IDs corresponding to the - ordered rows in the embeddings array. + Concept IDs aligned with the rows of ``embeddings``. embeddings : numpy.ndarray - A 2D array of shape (n_concepts, n_dimensions) containing the - generated vector representations. + Embedding matrix of shape ``(n_concepts, D)``. + _model_record : EmbeddingModelRecord + Internal registered-model record injected by ``@require_registered_model``. metric_type : Optional[MetricType] - The similarity metric type (e.g., COSINE, EUCLIDEAN) that should be - associated with the stored embeddings for accurate nearest neighbor - search behavior. If None, the index will not be built and the raw - embeddings are only stored in the HDF5 file for later use. + Optional metric used when creating/updating the FAISS index. 
Returns ------- diff --git a/src/omop_emb/backends/pgvector/pgvector_backend.py b/src/omop_emb/backends/pgvector/pgvector_backend.py index c76271d..548cb11 100644 --- a/src/omop_emb/backends/pgvector/pgvector_backend.py +++ b/src/omop_emb/backends/pgvector/pgvector_backend.py @@ -55,6 +55,24 @@ def upsert_embeddings( embeddings: ndarray, _model_record: EmbeddingModelRecord, ) -> None: + """ + Insert or update embeddings for OMOP concept IDs. + + Parameters + ---------- + model_name : str + Registered name of the embedding model. + index_type : IndexType + Storage index type used for this model's embeddings. + session : Session + SQLAlchemy session bound to the OMOP CDM database. + concept_ids : Sequence[int] + Concept IDs aligned with the rows of ``embeddings``. + embeddings : ndarray + Embedding matrix of shape ``(n_concepts, D)``. + _model_record : EmbeddingModelRecord + Internal registered-model record injected by ``@require_registered_model``. + """ concept_id_tuple = tuple(concept_ids) self.validate_embeddings_and_concept_ids( @@ -91,6 +109,22 @@ def get_embeddings_by_concept_ids( concept_ids: Sequence[int], _model_record: EmbeddingModelRecord, ) -> Mapping[int, Sequence[float]]: + """ + Return embeddings for the requested concept IDs. + + Parameters + ---------- + model_name : str + Registered name of the embedding model. + index_type : IndexType + Storage index type used for this model's embeddings. + session : Session + SQLAlchemy session bound to the OMOP CDM database. + concept_ids : Sequence[int] + Concept IDs to fetch from backend storage. + _model_record : EmbeddingModelRecord + Internal registered-model record injected by ``@require_registered_model``. + """ concept_id_tuple = tuple(concept_ids) if not concept_id_tuple: @@ -130,6 +164,28 @@ def get_nearest_concepts( k: int = 10, _model_record: EmbeddingModelRecord, ) -> Tuple[Tuple[NearestConceptMatch, ...], ...]: + """ + Return nearest stored concepts for query embeddings. 
+ + Parameters + ---------- + model_name : str + Registered name of the embedding model. + index_type : IndexType + Storage index type used for this model's embeddings. + session : Session + SQLAlchemy session bound to the OMOP CDM database. + query_embeddings : ndarray + Query embedding matrix of shape ``(Q, D)``. + metric_type : MetricType + Similarity or distance metric for nearest-neighbor search. + concept_filter : Optional[EmbeddingConceptFilter], optional + Optional filter restricting which OMOP concepts are considered. + k : int, optional + Number of nearest matches to return per query. + _model_record : EmbeddingModelRecord + Internal registered-model record injected by ``@require_registered_model``. + """ embedding_table = self.get_embedding_table( model_name=model_name, index_type=index_type, diff --git a/src/omop_emb/config.py b/src/omop_emb/config.py index 2ae83de..3650e27 100644 --- a/src/omop_emb/config.py +++ b/src/omop_emb/config.py @@ -2,7 +2,6 @@ from typing import Dict, Tuple ENV_OMOP_EMB_BACKEND = "OMOP_EMB_BACKEND" -ENV_OMOP_EMB_FAISS_INDEX_DIR = "OMOP_EMB_FAISS_INDEX_DIR" ENV_BASE_STORAGE_DIR = "OMOP_EMB_BASE_STORAGE_DIR" # For backends that use file-based storage, e.g. 
FAISS with on-disk indices or registry metadata storage From 2a3a4f0813cc6a5c7df0375745e91ddda0ab56e5 Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Wed, 8 Apr 2026 21:40:38 +0000 Subject: [PATCH 08/19] Fix spelling in Docs --- README.md | 8 ++++---- docs/usage/backend-selection.md | 2 +- docs/usage/cli.md | 2 +- mkdocs.yml | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 42a7d2a..822a0cd 100644 --- a/README.md +++ b/README.md @@ -40,11 +40,11 @@ Extended documentation can be found [here](https://AustralianCancerDataNetwork.g # Project Roadmap -- [x] Interface for postgres storage of vectors +- [x] Interface for PostgreSQL storage of vectors - [x] Interface for FAISS storage of embeddings -- [ ] Extensive unit testing - - [ ] Backend testing - - [ ] Corruption and restoration of DB testing +- [x] Extensive unit testing + - [x] Backend testing + - [x] Corruption and restoration of DB testing - [ ] Support non-Flat indices for each backend - [ ] `faiss` GPU support - [ ] [`pgvectorscale`](https://github.com/timescale/pgvectorscale) support diff --git a/docs/usage/backend-selection.md b/docs/usage/backend-selection.md index ca5daff..394ab8c 100644 --- a/docs/usage/backend-selection.md +++ b/docs/usage/backend-selection.md @@ -8,7 +8,7 @@ whatever happens to be installed. The current backend factory recognizes: -- `pgvector`: The [pgvector](https://github.com/pgvector/pgvector) extension to a standard postgres database to store embeddings directly in the database. +- `pgvector`: The [pgvector](https://github.com/pgvector/pgvector) extension to a standard PostgreSQL database to store embeddings directly in the database. - `faiss`: The [FAISS](https://github.com/facebookresearch/faiss) storage solution for on-disk storage. There is no implicit default backend name. 
You must pass one explicitly or set diff --git a/docs/usage/cli.md b/docs/usage/cli.md index 1e5c71f..3584329 100644 --- a/docs/usage/cli.md +++ b/docs/usage/cli.md @@ -19,7 +19,7 @@ At present, the production CLI path is PostgreSQL-oriented and stores embeddings pip install "omop-emb[faiss]" ``` -- **Database**: Postgres implementation of OMOP CDM. See [`omop-graph` documentation](reference-missing) for information how to setup. +- **Database**: PostgreSQL implementation of OMOP CDM. See [`omop-graph` documentation](reference-missing) for information how to setup. - **Environment**: `OMOP_DATABASE_URL` must be exported or present in `.env` (e.g., `postgresql://user:pass@localhost:5432/omop`). - **Backend config**: set `OMOP_EMB_BACKEND` (`pgvector` or `faiss`) and optionally `OMOP_EMB_BASE_STORAGE_DIR`. - **Connectivity**: Access to an OpenAI-compatible embeddings endpoint. *Currently only Ollama supported*. diff --git a/mkdocs.yml b/mkdocs.yml index ddfdb98..d982d7b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -31,7 +31,7 @@ nav: - Home: index.md - Getting Started: - Installation: usage/installation.md - - Backend Selection: usage/backend-selection.md + - Backend Selection: usage/backend-selection.md - "CLI Reference": usage/cli.md plugins: From 2124784eed9cc30ba284ff158fc4852e4176d89f Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Wed, 8 Apr 2026 22:20:35 +0000 Subject: [PATCH 09/19] Add missing components from cli.py --- src/omop_emb/cli.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/omop_emb/cli.py b/src/omop_emb/cli.py index 5d042b1..abcf216 100644 --- a/src/omop_emb/cli.py +++ b/src/omop_emb/cli.py @@ -43,10 +43,9 @@ def add_embeddings( "--backend", help="Embedding backend to use. Can be replaced by the `OMOP_EMB_BACKEND` environment variable." )] = None, - - faiss_base_dir: Annotated[Optional[str], typer.Option( - "--faiss-base-dir", - help="Optional base directory for FAISS backend storage." 
+ storage_base_dir: Annotated[Optional[str], typer.Option( + "--storage-base-dir", + help="Optional base directory for embedding backend storage. Reverts to `OMOP_EMB_BASE_STORAGE_DIR` environment variable if not provided, or defaults to $HOME/.omop_emb/ if neither is set. " )] = None, standard_only: Annotated[bool, typer.Option( "--standard-only", @@ -72,7 +71,7 @@ def add_embeddings( interface = EmbeddingInterface.from_backend_name( backend_name=backend_name, - faiss_base_dir=faiss_base_dir, + storage_base_dir=storage_base_dir, embedding_client=LLMClient( model=model, api_base=api_base, @@ -97,11 +96,13 @@ def add_embeddings( session=reader, model_name=model, concept_filter=concept_filter, + index_type=index_type ) concepts_without_embedding = interface.q_get_concepts_without_embedding( model_name=model, concept_filter=concept_filter, limit=num_embeddings, + index_type=index_type ) logger.info(f"Total concepts to process: {total_concepts}") @@ -117,6 +118,7 @@ def add_embeddings( concept_ids=tuple(batch_concepts.keys()), concept_texts=tuple(batch_concepts.values()), batch_size=batch_size, + index_type=index_type ) pbar.update(len(batch_concepts)) From d843089c477f097f5babd75b50124efc143ce495 Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Wed, 8 Apr 2026 23:07:42 +0000 Subject: [PATCH 10/19] Interface validation for better accessibility of the library from omop-graph --- src/omop_emb/__init__.py | 12 ++++++- src/omop_emb/backends/factory.py | 7 ++-- src/omop_emb/config.py | 39 +++++++++++++++++++++++ src/omop_emb/interface.py | 15 +++++---- tests/test_interface_validation.py | 51 ++++++++++++++++++++++++++++++ 5 files changed, 113 insertions(+), 11 deletions(-) create mode 100644 tests/test_interface_validation.py diff --git a/src/omop_emb/__init__.py b/src/omop_emb/__init__.py index 6ce8b62..1b28060 100644 --- a/src/omop_emb/__init__.py +++ b/src/omop_emb/__init__.py @@ -1,6 +1,13 @@ from .interface import EmbeddingInterface from .backends.base import
EmbeddingConceptFilter -from .config import BackendType, IndexType, MetricType +from .config import ( + BackendType, + IndexType, + MetricType, + parse_backend_type, + parse_index_type, + parse_metric_type, +) from .backends.factory import get_embedding_backend __all__ = [ @@ -9,5 +16,8 @@ "BackendType", "IndexType", "MetricType", + "parse_backend_type", + "parse_index_type", + "parse_metric_type", "get_embedding_backend", ] diff --git a/src/omop_emb/backends/factory.py b/src/omop_emb/backends/factory.py index e7fb670..577e7aa 100644 --- a/src/omop_emb/backends/factory.py +++ b/src/omop_emb/backends/factory.py @@ -7,6 +7,7 @@ from ..config import ( BackendType, ENV_OMOP_EMB_BACKEND, + parse_backend_type, ) from omop_emb.utils.errors import ( EmbeddingBackendConfigurationError, @@ -14,7 +15,7 @@ UnknownEmbeddingBackendError, ) -def normalize_backend_name(backend_name: Optional[str]) -> BackendType: +def normalize_backend_name(backend_name: Optional[str | BackendType]) -> BackendType: """ Normalize an embedding backend name from an explicit argument or env var. @@ -28,7 +29,7 @@ def normalize_backend_name(backend_name: Optional[str]) -> BackendType: raise AttributeError(f"No embedding backend specified. Provide an explicit backend_name or set the {ENV_OMOP_EMB_BACKEND} environment variable.") else: try: - backend_type = BackendType(backend_name) + backend_type = parse_backend_type(backend_name) except ValueError: raise UnknownEmbeddingBackendError( f"Unknown embedding backend {backend_name!r}. 
" @@ -38,7 +39,7 @@ def normalize_backend_name(backend_name: Optional[str]) -> BackendType: def get_embedding_backend( - backend_name: Optional[str] = None, + backend_name: Optional[str | BackendType] = None, storage_base_dir: Optional[str] = None, registry_db_name: Optional[str] = None, ) -> EmbeddingBackend: diff --git a/src/omop_emb/config.py b/src/omop_emb/config.py index 3650e27..3096dd2 100644 --- a/src/omop_emb/config.py +++ b/src/omop_emb/config.py @@ -39,6 +39,45 @@ class MetricType(StrEnum): HAMMING = "hamming" JACCARD = "jaccard" + +def parse_backend_type(value: str | BackendType) -> BackendType: + """Normalize a backend value to ``BackendType`` with a clear error message.""" + if isinstance(value, BackendType): + return value + try: + return BackendType(value) + except ValueError as exc: + raise ValueError( + f"Invalid backend type {value!r}. Expected one of " + f"{[member.value for member in BackendType]}." + ) from exc + + +def parse_index_type(value: str | IndexType) -> IndexType: + """Normalize an index value to ``IndexType`` with a clear error message.""" + if isinstance(value, IndexType): + return value + try: + return IndexType(value) + except ValueError as exc: + raise ValueError( + f"Invalid index type {value!r}. Expected one of " + f"{[member.value for member in IndexType]}." + ) from exc + + +def parse_metric_type(value: str | MetricType) -> MetricType: + """Normalize a metric value to ``MetricType`` with a clear error message.""" + if isinstance(value, MetricType): + return value + try: + return MetricType(value) + except ValueError as exc: + raise ValueError( + f"Invalid metric type {value!r}. Expected one of " + f"{[member.value for member in MetricType]}." 
+ ) from exc + # TODO: Support non-flat indices in the future SUPPORTED_INDICES_AND_METRICS_PER_BACKEND: Dict[BackendType, Dict[IndexType, Tuple[MetricType, ...]]] = { #https://github.com/pgvector/pgvector?tab=readme-ov-file#querying diff --git a/src/omop_emb/interface.py b/src/omop_emb/interface.py index 1bba90e..bfaf719 100644 --- a/src/omop_emb/interface.py +++ b/src/omop_emb/interface.py @@ -16,7 +16,7 @@ ) from omop_emb.utils.embedding_utils import EmbeddingConceptFilter from omop_emb.model_registry import EmbeddingModelRecord -from .config import IndexType, MetricType +from .config import BackendType, IndexType, MetricType @dataclass @@ -47,7 +47,7 @@ def embedding_dim(self) -> Optional[int]: def from_backend_name( cls, embedding_client: Optional[LLMClient] = None, - backend_name: Optional[str] = None, + backend_name: Optional[str | BackendType] = None, storage_base_dir: Optional[str] = None, registry_db_name: Optional[str] = None, ) -> EmbeddingInterface: @@ -119,7 +119,7 @@ def get_nearest_concepts( index_type: IndexType, query_embedding: np.ndarray, *, - metric_type: MetricType | str, + metric_type: MetricType, concept_filter: Optional[EmbeddingConceptFilter] = None, k: int = 10, ) -> Tuple[Mapping[int, float], ...]: @@ -138,7 +138,7 @@ def get_nearest_concepts( query_embedding : ndarray The embedding vector to search with. Expected shape is (q, dimension) where q is the number of query vectors and dimension is the size of the embedding space for the model. - metric_type : MetricType, str + metric_type : MetricType The similarity or distance metric to use for nearest neighbor search. This must be compatible with the index type used by the database. concept_filter : Optional[EmbeddingConceptFilter], optional A filter to specify which concepts to consider as potential nearest neighbors. @@ -154,9 +154,10 @@ def get_nearest_concepts( Tuple[Mapping[int, float], ...] A tuple of dictionaries containing nearest concept matches for each query vector. 
The outer tuple corresponds to the query vectors in order, and each inner dictionary contains the nearest matches for that query vector, sorted by similarity. Returned shape is (q, k) where q is the number of query vectors and k is the number of nearest neighbors returned per query. """ - if isinstance(metric_type, str): - metric_type = MetricType(metric_type) - assert isinstance(metric_type, MetricType), f"metric_type must be a MetricType enum or string. Got {type(metric_type)}" + if not isinstance(metric_type, MetricType): + raise TypeError( + f"metric_type must be MetricType, got {type(metric_type).__name__}." + ) nearest_concepts = self.backend.get_nearest_concepts( session=session, model_name=model_name, diff --git a/tests/test_interface_validation.py b/tests/test_interface_validation.py new file mode 100644 index 0000000..ec0d0c7 --- /dev/null +++ b/tests/test_interface_validation.py @@ -0,0 +1,51 @@ +"""Validation tests for strict EmbeddingInterface input contracts.""" + +from unittest.mock import Mock + +import numpy as np +import pytest + +from omop_emb.config import IndexType, MetricType +from omop_emb.interface import EmbeddingInterface + + +@pytest.mark.unit +class TestInterfaceValidation: + def test_get_nearest_concepts_requires_index_type(self): + """Index type is required as part of the strict core interface.""" + interface = EmbeddingInterface(backend=Mock(), embedding_client=Mock()) + kwargs = { + "session": Mock(), + "model_name": "test-model", + "query_embedding": np.zeros((1, 1), dtype=np.float32), + "metric_type": MetricType.COSINE, + } + + with pytest.raises(TypeError): + interface.get_nearest_concepts(**kwargs) + + def test_get_nearest_concepts_requires_metric_type(self): + """Metric type is required as part of the strict core interface.""" + interface = EmbeddingInterface(backend=Mock(), embedding_client=Mock()) + kwargs = { + "session": Mock(), + "model_name": "test-model", + "index_type": IndexType.FLAT, + "query_embedding": np.zeros((1, 
1), dtype=np.float32), + } + + with pytest.raises(TypeError): + interface.get_nearest_concepts(**kwargs) + + def test_get_nearest_concepts_rejects_non_enum_metric_type(self): + """Core interface rejects non-MetricType values with a clear error.""" + interface = EmbeddingInterface(backend=Mock(), embedding_client=Mock()) + + with pytest.raises(TypeError, match="metric_type must be MetricType"): + interface.get_nearest_concepts( + session=Mock(), + model_name="test-model", + index_type=IndexType.FLAT, + query_embedding=np.zeros((1, 1), dtype=np.float32), + metric_type="cosine", # type: ignore[arg-type] + ) From 8b55454dd6fbdeb342eb88b6674431719b59fe33 Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Thu, 9 Apr 2026 01:43:12 +0000 Subject: [PATCH 11/19] Import/Export CLI utils for pgvector storage to disk, storage in correct dir of embedding artefacts --- .gitignore | 1 + README.md | 2 +- docs/index.md | 2 +- docs/usage/backend-selection.md | 7 +- docs/usage/cli.md | 58 +++++++- docs/usage/installation.md | 4 +- src/omop_emb/backends/base.py | 4 +- src/omop_emb/cli.py | 237 ++++++++++++++++++++++++++++++-- 8 files changed, 298 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index d50942f..1383310 100644 --- a/.gitignore +++ b/.gitignore @@ -137,6 +137,7 @@ celerybeat.pid # Environments .env .envrc +.omop_emb/ .venv env/ venv/ diff --git a/README.md b/README.md index 822a0cd..f3dd4b3 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ Notes: Common environment variables: - `OMOP_EMB_BACKEND`: backend name (`pgvector` or `faiss`) used by the backend factory. -- `OMOP_EMB_BASE_STORAGE_DIR`: local base directory for `omop-emb` artifacts, including local metadata (`metadata.db`) and FAISS files. +- `OMOP_EMB_BASE_STORAGE_DIR`: local base directory for `omop-emb` artifacts, including local metadata (`metadata.db`) and FAISS files. If unset, `omop-emb` defaults to `./.omop_emb` in the current working directory. 
- `OMOP_DATABASE_URL`: SQLAlchemy URL for the OMOP CDM database. Extended documentation can be found [here](https://AustralianCancerDataNetwork.github.io/omop-emb). diff --git a/docs/index.md b/docs/index.md index fc36fbd..01f4196 100644 --- a/docs/index.md +++ b/docs/index.md @@ -34,7 +34,7 @@ At runtime, backend choice should also be explicit. The intended direction is: Recommended runtime environment variables: - `OMOP_EMB_BACKEND` (`pgvector` or `faiss`) -- `OMOP_EMB_BASE_STORAGE_DIR` (base directory for local metadata and FAISS artifacts) +- `OMOP_EMB_BASE_STORAGE_DIR` (base directory for local metadata and FAISS artifacts; defaults to `./.omop_emb` in the current working directory if unset) - `OMOP_DATABASE_URL` (OMOP CDM database URL) diff --git a/docs/usage/backend-selection.md b/docs/usage/backend-selection.md index 394ab8c..d1f2c6d 100644 --- a/docs/usage/backend-selection.md +++ b/docs/usage/backend-selection.md @@ -26,11 +26,16 @@ Examples: ```bash export OMOP_EMB_BACKEND=pgvector export OMOP_EMB_BACKEND=faiss -export OMOP_EMB_BASE_STORAGE_DIR=$HOME/.omop_emb +export OMOP_EMB_BASE_STORAGE_DIR=$PWD/.omop_emb ``` You can also pass the backend name directly in Python. +Storage directory behavior: + +- If `OMOP_EMB_BASE_STORAGE_DIR` is unset and no explicit path is passed, `omop-emb` defaults to `./.omop_emb` in the current working directory. +- If a path includes `~`, it is expanded (for example `~/.omop_emb`). + ## Python factory The backend factory lives in `omop_emb.backends`: diff --git a/docs/usage/cli.md b/docs/usage/cli.md index 3584329..d268901 100644 --- a/docs/usage/cli.md +++ b/docs/usage/cli.md @@ -19,7 +19,7 @@ At present, the production CLI path is PostgreSQL-oriented and stores embeddings pip install "omop-emb[faiss]" ``` -- **Database**: PostgreSQL implementation of OMOP CDM. See [`omop-graph` documentation](reference-missing) for information how to setup. +- **Database**: PostgreSQL implementation of OMOP CDM. 
See [`omop-graph` documentation](https://AustralianCancerDataNetwork.github.io/omop-graph) for information how to setup. - **Environment**: `OMOP_DATABASE_URL` must be exported or present in `.env` (e.g., `postgresql://user:pass@localhost:5432/omop`). - **Backend config**: set `OMOP_EMB_BACKEND` (`pgvector` or `faiss`) and optionally `OMOP_EMB_BASE_STORAGE_DIR`. - **Connectivity**: Access to an OpenAI-compatible embeddings endpoint. *Currently only Ollama supported*. @@ -61,4 +61,58 @@ where `[OPTIONS]` are optional arguments that can be specified as described belo - `OMOP_DATABASE_URL`: OMOP CDM database connection string. - `OMOP_EMB_BACKEND`: backend selector used when `--backend` is not provided. -- `OMOP_EMB_BASE_STORAGE_DIR`: local storage root for metadata and file-based artifacts. +- `OMOP_EMB_BASE_STORAGE_DIR`: local storage root for metadata and file-based artifacts. If unset, `omop-emb` defaults to `./.omop_emb` in the current working directory. + +Paths that include `~` are expanded automatically. + +--- + +## `export-pgvector` + +Export pgvector embedding tables to CSV files plus a manifest so they can be restored later. + +### Usage +```bash +omop-emb export-pgvector --output-dir [OPTIONS] +``` + +### Options + +| Option | Short | Type | Default | Description | +| :--- | :--- | :--- | :--- | :--- | +| **`--output-dir`** | `-o` | `String` | **Required** | Directory where snapshot files are written. | +| **`--storage-base-dir`** | | `String` | `None` | Optional path to local metadata registry (`metadata.db`). If unset, falls back to `OMOP_EMB_BASE_STORAGE_DIR`, otherwise defaults to `./.omop_emb` in the current working directory. Paths with `~` are expanded. | +| **`--model`** | `-m` | `List[String]` | `None` | Optional model-name filter. Repeat to export specific models only. | +| **`--index-type`** | | `IndexType` | `None` | Optional index-type filter. 
| + +### Output + +The command writes: + +- `manifest.json`: snapshot metadata and table mapping +- One CSV per embedding table named `.csv` + +--- + +## `import-pgvector` + +Restore pgvector embedding tables from files previously created by `export-pgvector`. + +### Usage +```bash +omop-emb import-pgvector --input-dir [OPTIONS] +``` + +### Options + +| Option | Short | Type | Default | Description | +| :--- | :--- | :--- | :--- | :--- | +| **`--input-dir`** | `-i` | `String` | **Required** | Directory containing `manifest.json` and CSV files. | +| **`--storage-base-dir`** | | `String` | `None` | Optional path to local metadata registry (`metadata.db`). If unset, falls back to `OMOP_EMB_BASE_STORAGE_DIR`, otherwise defaults to `./.omop_emb` in the current working directory. Paths with `~` are expanded. | +| **`--replace`** | | `Boolean` | `False` | If set, truncate destination embedding tables before import. | +| **`--batch-size`** | `-b` | `Integer` | `5000` | Number of rows inserted per SQL batch. | + +### Notes + +- Import re-registers pgvector models into local metadata before loading rows. +- Import uses upsert semantics (`ON CONFLICT (concept_id) DO UPDATE`) unless `--replace` is set. diff --git a/docs/usage/installation.md b/docs/usage/installation.md index 07d7186..5819d3b 100644 --- a/docs/usage/installation.md +++ b/docs/usage/installation.md @@ -54,13 +54,15 @@ Examples: ```bash export OMOP_EMB_BACKEND=pgvector export OMOP_EMB_BACKEND=faiss -export OMOP_EMB_BASE_STORAGE_DIR=$HOME/.omop_emb +export OMOP_EMB_BASE_STORAGE_DIR=$PWD/.omop_emb ``` That avoids silent fallback between backend implementations. `OMOP_EMB_BASE_STORAGE_DIR` controls where `omop-emb` stores local metadata (`metadata.db`) and file-based backend artifacts (such as FAISS files). +If it is not set, `omop-emb` defaults to `./.omop_emb` in the current working directory. +If a provided path includes `~`, it is expanded automatically. 
## Current database support caveat diff --git a/src/omop_emb/backends/base.py b/src/omop_emb/backends/base.py index d8b418d..3afd78d 100644 --- a/src/omop_emb/backends/base.py +++ b/src/omop_emb/backends/base.py @@ -68,7 +68,7 @@ class EmbeddingBackend(ABC, Generic[T]): to validate models, resolve concept metadata, and apply domain/vocabulary filters. """ - DEFAULT_BASE_STORAGE_DIR = str(Path.home() / ".omop_emb") + DEFAULT_BASE_STORAGE_DIR = ".omop_emb" def __init__( self, storage_base_dir: Optional[str | Path] = None, @@ -78,7 +78,7 @@ def __init__( # Local storage for model registry database and more storage_base_dir = storage_base_dir or os.getenv(ENV_BASE_STORAGE_DIR) or self.DEFAULT_BASE_STORAGE_DIR - self._storage_base_dir = Path(storage_base_dir) + self._storage_base_dir = Path(storage_base_dir).expanduser() self._storage_base_dir.mkdir(parents=True, exist_ok=True) diff --git a/src/omop_emb/cli.py b/src/omop_emb/cli.py index abcf216..8690b7e 100644 --- a/src/omop_emb/cli.py +++ b/src/omop_emb/cli.py @@ -4,9 +4,11 @@ from sqlalchemy.orm import Session from orm_loader.helpers import get_logger, configure_logging, create_db -from omop_alchemy.cdm.model.vocabulary import Concept from typing import Annotated, Optional +from pathlib import Path +import csv +import json import os from dotenv import load_dotenv from tqdm import tqdm @@ -14,11 +16,33 @@ from omop_emb.utils.embedding_utils import EmbeddingConceptFilter from omop_emb.interface import EmbeddingInterface -from omop_emb.config import IndexType +from omop_emb.config import IndexType, BackendType app = typer.Typer() logger = get_logger(__name__) +SNAPSHOT_MANIFEST_NAME = "manifest.json" + + +def _resolve_engine() -> sa.Engine: + engine_string = os.getenv('OMOP_DATABASE_URL') + if engine_string is None: + raise RuntimeError("OMOP_DATABASE_URL environment variable not set. 
Please set it in your .env file to point to your database.") + + engine = sa.create_engine(engine_string, future=True, echo=False) + assert engine.dialect.name == "postgresql", "Only PostgreSQL databases are supported for embedding storage with the current backends. Please check your `OMOP_DATABASE_URL` environment variable and ensure it points to a PostgreSQL database." + return engine + + +def _build_pgvector_interface(storage_base_dir: Optional[str]) -> EmbeddingInterface: + interface = EmbeddingInterface.from_backend_name( + backend_name=BackendType.PGVECTOR, + storage_base_dir=storage_base_dir, + ) + if interface.backend.backend_type != BackendType.PGVECTOR: + raise RuntimeError("Resolved embedding backend is not pgvector. Set --storage-base-dir and/or OMOP_EMB_BACKEND appropriately.") + return interface + @app.command() def add_embeddings( @@ -45,7 +69,7 @@ def add_embeddings( )] = None, storage_base_dir: Annotated[Optional[str], typer.Option( "--storage-base-dir", - help="Optional base directory for embedding backend storage. Reverts to `OMOP_EMB_BASE_STORAGE_DIR` environment variable if not provided, or defaults to $HOME/.omop_emb/ if neither is set. " + help="Optional base directory for embedding backend storage. Reverts to `OMOP_EMB_BASE_STORAGE_DIR` if not provided, or defaults to ./.omop_emb in the current working directory. Paths with `~` are expanded." )] = None, standard_only: Annotated[bool, typer.Option( "--standard-only", @@ -62,12 +86,7 @@ def add_embeddings( configure_logging() load_dotenv() - engine_string = os.getenv('OMOP_DATABASE_URL') - if engine_string is None: - raise RuntimeError("OMOP_DATABASE_URL environment variable not set. Please set it in your .env file to point to your database.") - - engine = sa.create_engine(engine_string, future=True, echo=False) - assert engine.dialect.name == "postgresql", "Only PostgreSQL databases are supported for embedding storage with the current backends. 
Please check your `OMOP_DATABASE_URL` environment variable and ensure it points to a PostgreSQL database." + engine = _resolve_engine() interface = EmbeddingInterface.from_backend_name( backend_name=backend_name, @@ -126,5 +145,205 @@ def add_embeddings( logger.info("Completed embedding generation and storage.") +@app.command() +def export_pgvector( + output_dir: Annotated[str, typer.Option( + "--output-dir", "-o", + help="Directory where pgvector snapshot files (CSV + manifest) are written.", + )], + storage_base_dir: Annotated[Optional[str], typer.Option( + "--storage-base-dir", + help="Optional base directory for omop-emb metadata registry (metadata.db). Reverts to `OMOP_EMB_BASE_STORAGE_DIR` if not provided, or defaults to ./.omop_emb in the current working directory. Paths with `~` are expanded.", + )] = None, + model: Annotated[Optional[list[str]], typer.Option( + "--model", "-m", + help="Optional model-name filter. Repeat to export specific models only.", + )] = None, + index_type: Annotated[Optional[IndexType], typer.Option( + "--index-type", + help="Optional index-type filter.", + )] = None, +): + """Export pgvector embedding tables to file for checkpoint/restore workflows.""" + configure_logging() + load_dotenv() + + engine = _resolve_engine() + interface = _build_pgvector_interface(storage_base_dir=storage_base_dir) + interface.initialise_store(engine) + + records = interface.backend.embedding_model_registry.get_registered_models_from_db( + backend_type=BackendType.PGVECTOR + ) + if records is None: + raise RuntimeError("No pgvector models found in local registry metadata. 
Nothing to export.") + + model_filter = set(model) if model else None + selected_records = [ + record for record in records + if (model_filter is None or record.model_name in model_filter) + and (index_type is None or record.index_type == index_type) + ] + if not selected_records: + raise RuntimeError("No pgvector models matched the provided filters.") + + output_path = Path(output_dir).resolve() + output_path.mkdir(parents=True, exist_ok=True) + + manifest = { + "format_version": 1, + "backend": BackendType.PGVECTOR.value, + "tables": [], + } + + with engine.connect() as conn: + for record in selected_records: + table_name = record.storage_identifier + quoted_table_name = f'"{table_name}"' + row_count = conn.execute(sa.text(f"SELECT COUNT(*) FROM {quoted_table_name}")) + n_rows = int(row_count.scalar_one()) + + csv_filename = f"{table_name}.csv" + csv_path = output_path / csv_filename + + query = sa.text( + f"SELECT concept_id, embedding::text AS embedding FROM {quoted_table_name} ORDER BY concept_id" + ) + result = conn.execution_options(stream_results=True).execute(query) + + with csv_path.open("w", encoding="utf-8", newline="") as handle: + writer = csv.writer(handle) + writer.writerow(["concept_id", "embedding"]) + for row in result: + writer.writerow([int(row.concept_id), str(row.embedding)]) + + logger.info(f"Exported {n_rows} rows from {table_name} to {csv_path}") + manifest["tables"].append( + { + "model_name": record.model_name, + "index_type": record.index_type.value, + "dimensions": record.dimensions, + "storage_identifier": record.storage_identifier, + "metadata": dict(record.metadata), + "rows": n_rows, + "file": csv_filename, + } + ) + + manifest_path = output_path / SNAPSHOT_MANIFEST_NAME + manifest_path.write_text(json.dumps(manifest, indent=2), encoding="utf-8") + logger.info(f"Wrote pgvector snapshot manifest to {manifest_path}") + + +@app.command() +def import_pgvector( + input_dir: Annotated[str, typer.Option( + "--input-dir", "-i", + 
help="Directory containing pgvector snapshot files and manifest.json.", + )], + storage_base_dir: Annotated[Optional[str], typer.Option( + "--storage-base-dir", + help="Optional base directory for omop-emb metadata registry (metadata.db). Reverts to `OMOP_EMB_BASE_STORAGE_DIR` if not provided, or defaults to ./.omop_emb in the current working directory. Paths with `~` are expanded.", + )] = None, + replace: Annotated[bool, typer.Option( + "--replace", + help="If set, truncate each destination embedding table before import.", + )] = False, + batch_size: Annotated[int, typer.Option( + "--batch-size", "-b", + help="Number of rows per INSERT batch.", + )] = 5000, +): + """Import pgvector embedding tables from a previously exported snapshot.""" + configure_logging() + load_dotenv() + + input_path = Path(input_dir).resolve() + manifest_path = input_path / SNAPSHOT_MANIFEST_NAME + if not manifest_path.is_file(): + raise FileNotFoundError(f"Snapshot manifest not found at {manifest_path}") + + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + if manifest.get("backend") != BackendType.PGVECTOR.value: + raise ValueError("Snapshot manifest backend is not pgvector.") + + table_specs = manifest.get("tables") + if not isinstance(table_specs, list) or not table_specs: + raise ValueError("Snapshot manifest does not contain any tables.") + + engine = _resolve_engine() + interface = _build_pgvector_interface(storage_base_dir=storage_base_dir) + + for table_spec in table_specs: + model_name = str(table_spec["model_name"]) + index_type_enum = IndexType(str(table_spec["index_type"])) + dimensions = int(table_spec["dimensions"]) + expected_table_name = str(table_spec["storage_identifier"]) + metadata = table_spec.get("metadata") or {} + + interface.register_model( + engine=engine, + model_name=model_name, + dimensions=dimensions, + index_type=index_type_enum, + metadata=metadata, + ) + actual_table_name = interface.get_model_table_name( + model_name=model_name, + 
index_type=index_type_enum, + ) + if actual_table_name != expected_table_name: + raise RuntimeError( + f"Storage identifier mismatch for model '{model_name}': expected '{expected_table_name}', got '{actual_table_name}'." + ) + + interface.initialise_store(engine) + + with Session(engine) as session: + for table_spec in table_specs: + table_name = str(table_spec["storage_identifier"]) + csv_file = input_path / str(table_spec["file"]) + if not csv_file.is_file(): + raise FileNotFoundError(f"Missing snapshot table file: {csv_file}") + + quoted_table_name = f'"{table_name}"' + if replace: + session.execute(sa.text(f"TRUNCATE TABLE {quoted_table_name}")) + session.commit() + + insert_sql = sa.text( + f""" + INSERT INTO {quoted_table_name} (concept_id, embedding) + VALUES (:concept_id, CAST(:embedding AS vector)) + ON CONFLICT (concept_id) DO UPDATE + SET embedding = EXCLUDED.embedding + """ + ) + + total_rows = 0 + batch: list[dict[str, object]] = [] + with csv_file.open("r", encoding="utf-8", newline="") as handle: + reader = csv.DictReader(handle) + for row in reader: + batch.append( + { + "concept_id": int(row["concept_id"]), + "embedding": row["embedding"], + } + ) + if len(batch) >= batch_size: + session.execute(insert_sql, batch) + session.commit() + total_rows += len(batch) + batch = [] + + if batch: + session.execute(insert_sql, batch) + session.commit() + total_rows += len(batch) + + logger.info(f"Imported {total_rows} rows into {table_name} from {csv_file}") + + if __name__ == "__main__": app() From 8070bef6de36de2305f8612b8cc6d0cf72455e61 Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Thu, 9 Apr 2026 03:46:09 +0000 Subject: [PATCH 12/19] Add tests and support legacy model_registry in CLI --- docs/usage/cli.md | 57 +++++- src/omop_emb/cli.py | 108 ++++++++++ .../model_registry/model_registry_manager.py | 11 +- tests/test_cli_pgvector_snapshot.py | 184 ++++++++++++++++++ 4 files changed, 355 insertions(+), 5 deletions(-) create mode 100644 
tests/test_cli_pgvector_snapshot.py diff --git a/docs/usage/cli.md b/docs/usage/cli.md index d268901..13fb513 100644 --- a/docs/usage/cli.md +++ b/docs/usage/cli.md @@ -40,8 +40,6 @@ omop-emb add-embeddings --api-base --api-key [OPTIONS] where `[OPTIONS]` are optional arguments that can be specified as described below. -### Command Options - ### Command Options | Option | Short | Type | Default | Description | @@ -52,7 +50,7 @@ where `[OPTIONS]` are optional arguments that can be specified as described belo | **`--batch-size`** | `-b` | `Integer` | `100` | Number of concepts to process in each chunk. | | **`--model`** | `-m` | `String` | `text-embedding-3-small` | Name of the embedding model to use for generating vectors. | | **`--backend`** | | `Literal['pgvector', 'faiss']` | `None` | Embedding backend to use (can be replaced by `OMOP_EMB_BACKEND`). Requires the corresponding optional dependency. | -| **`--faiss-base-dir`** | | `String` | `None` | Optional base directory for FAISS backend storage. | +| **`--storage-base-dir`** | | `String` | `None` | Optional base directory for backend storage and local metadata registry (`metadata.db`). | | **`--standard-only`** | | `Boolean` | `False` | If set, only generate embeddings for OMOP standard concepts (`standard_concept = 'S'`). | | **`--vocabulary`** | | `List[String]` | `None` | Filter to embed concepts only from specific OMOP vocabularies. | | **`--num-embeddings`** | `-n` | `Integer` | `None` | Limit the number of concepts processed (useful for testing). | @@ -116,3 +114,56 @@ omop-emb import-pgvector --input-dir [OPTIONS] - Import re-registers pgvector models into local metadata before loading rows. - Import uses upsert semantics (`ON CONFLICT (concept_id) DO UPDATE`) unless `--replace` is set. + +--- + +## `migrate-legacy-pgvector-registry` + +Migrate legacy pgvector registry rows from a source database table into the local metadata registry (`metadata.db`). 
+ +This command is intended for compatibility with older setups that kept registry metadata in the database instead of the local metadata store. + +### Usage +```bash +omop-emb migrate-legacy-pgvector-registry [OPTIONS] +``` + +### Options + +| Option | Type | Default | Description | +| :--- | :--- | :--- | :--- | +| **`--storage-base-dir`** | `String` | `None` | Optional path to local metadata registry location. If unset, falls back to `OMOP_EMB_BASE_STORAGE_DIR`, otherwise defaults to `./.omop_emb` in the current working directory. | +| **`--source-database-url`** | `String` | `OMOP_DATABASE_URL` | Source database URL containing the legacy registry table. | +| **`--legacy-table`** | `String` | `model_registry` | Name of the legacy registry table in the source database. | +| **`--dry-run`** | `Boolean` | `False` | Show what would be migrated without writing changes. | +| **`--drop-legacy-registry`** | `Boolean` | `False` | Drop the legacy table after successful migration. | + +### Recommended Migration Flow + +1. Validate what will migrate: + +```bash +omop-emb migrate-legacy-pgvector-registry --dry-run +``` + +2. Run the migration: + +```bash +omop-emb migrate-legacy-pgvector-registry +``` + +3. 
Optionally remove legacy table after verification: + +```bash +omop-emb migrate-legacy-pgvector-registry --drop-legacy-registry +``` + +### Field Mapping + +The migration command supports these legacy field names when reading rows: + +- model name: `model_name` +- dimensions: `dimensions` +- index type: `index_type` (fallback: `index_method`) +- storage identifier: `storage_identifier` (fallback: `table_name`) +- metadata: `details` (fallback: `metadata`) diff --git a/src/omop_emb/cli.py b/src/omop_emb/cli.py index 8690b7e..7baf4f1 100644 --- a/src/omop_emb/cli.py +++ b/src/omop_emb/cli.py @@ -24,6 +24,43 @@ SNAPSHOT_MANIFEST_NAME = "manifest.json" +def _load_legacy_rows(engine: sa.Engine, legacy_table: str) -> list[dict[str, object]]: + inspector = sa.inspect(engine) + if not inspector.has_table(legacy_table): + raise RuntimeError(f"Legacy table '{legacy_table}' does not exist in source database.") + + query = sa.text(f'SELECT * FROM "{legacy_table}"') + with engine.connect() as conn: + rows = conn.execute(query).mappings().all() + return [dict(row) for row in rows] + + +def _legacy_row_fields(row: dict[str, object]) -> tuple[str, int, IndexType, str, dict[str, object]]: + model_name_raw = row.get("model_name") + dimensions_raw = row.get("dimensions") + if model_name_raw is None or dimensions_raw is None: + raise ValueError("Legacy row missing required fields 'model_name' and/or 'dimensions'.") + + index_raw = row.get("index_type") or row.get("index_method") or IndexType.FLAT.value + index_type = IndexType(str(index_raw)) + + storage_identifier_raw = row.get("storage_identifier") or row.get("table_name") + if storage_identifier_raw is None: + raise ValueError("Legacy row missing required storage field ('storage_identifier' or 'table_name').") + storage_identifier = str(storage_identifier_raw) + + details_raw = row.get("details") or row.get("metadata") or {} + if isinstance(details_raw, dict): + metadata = dict(details_raw) + else: + metadata = {} + + model_name = 
str(model_name_raw) + dimensions = int(str(dimensions_raw)) + + return model_name, dimensions, index_type, storage_identifier, metadata + + def _resolve_engine() -> sa.Engine: engine_string = os.getenv('OMOP_DATABASE_URL') if engine_string is None: @@ -345,5 +382,76 @@ def import_pgvector( logger.info(f"Imported {total_rows} rows into {table_name} from {csv_file}") +@app.command() +def migrate_legacy_pgvector_registry( + storage_base_dir: Annotated[Optional[str], typer.Option( + "--storage-base-dir", + help="Optional base directory for omop-emb metadata registry (metadata.db). Reverts to `OMOP_EMB_BASE_STORAGE_DIR` if not provided, or defaults to ./.omop_emb in the current working directory. Paths with `~` are expanded.", + )] = None, + source_database_url: Annotated[Optional[str], typer.Option( + "--source-database-url", + help="Source database URL containing the legacy model_registry table. Defaults to OMOP_DATABASE_URL.", + )] = None, + legacy_table: Annotated[str, typer.Option( + "--legacy-table", + help="Name of the legacy table to migrate from.", + )] = "model_registry", + dry_run: Annotated[bool, typer.Option( + "--dry-run", + help="If set, report rows that would be migrated without writing to local metadata.", + )] = False, + drop_legacy_registry: Annotated[bool, typer.Option( + "--drop-legacy-registry", + help="If set, drop the legacy table after successful migration.", + )] = False, +): + """Migrate legacy pgvector registry rows into local metadata.db registry.""" + configure_logging() + load_dotenv() + + source_url = source_database_url or os.getenv("OMOP_DATABASE_URL") + if source_url is None: + raise RuntimeError("OMOP_DATABASE_URL is not set. 
Provide --source-database-url.") + + source_engine = sa.create_engine(source_url, future=True, echo=False) + interface = _build_pgvector_interface(storage_base_dir=storage_base_dir) + + legacy_rows = _load_legacy_rows(source_engine, legacy_table=legacy_table) + if not legacy_rows: + logger.info("No legacy registry rows found. Nothing to migrate.") + return + + logger.info(f"Found {len(legacy_rows)} legacy registry rows in '{legacy_table}'.") + + migrated = 0 + for row in legacy_rows: + model_name, dimensions, index_type, storage_identifier, metadata = _legacy_row_fields(row) + + if dry_run: + logger.info( + f"[DRY RUN] Would migrate model={model_name}, index={index_type.value}, " + f"dimensions={dimensions}, storage_identifier={storage_identifier}" + ) + migrated += 1 + continue + + interface.backend.embedding_model_registry.register_model( + model_name=model_name, + dimensions=dimensions, + backend_type=BackendType.PGVECTOR, + index_type=index_type, + metadata=metadata, + storage_identifier=storage_identifier, + ) + migrated += 1 + + logger.info(f"Migrated {migrated} legacy registry rows into local metadata registry.") + + if drop_legacy_registry and not dry_run: + with source_engine.begin() as conn: + conn.execute(sa.text(f'DROP TABLE IF EXISTS "{legacy_table}"')) + logger.info(f"Dropped legacy registry table '{legacy_table}'.") + + if __name__ == "__main__": app() diff --git a/src/omop_emb/model_registry/model_registry_manager.py b/src/omop_emb/model_registry/model_registry_manager.py index 18bd10b..947c9e4 100644 --- a/src/omop_emb/model_registry/model_registry_manager.py +++ b/src/omop_emb/model_registry/model_registry_manager.py @@ -67,6 +67,7 @@ def register_model( backend_type: BackendType, index_type: IndexType, metadata: Optional[Mapping[str, object]] = None, + storage_identifier: Optional[str] = None, ) -> EmbeddingModelRecord: """ Shared template method for model registration. @@ -110,13 +111,19 @@ def register_model( f"metadata. 
Reuse the existing model name or choose a new one.", conflict_field="metadata" ) + if storage_identifier is not None and existing_row.storage_identifier != storage_identifier: + raise ModelRegistrationConflictError( + f"Model '{model_name}' is already registered with storage identifier " + f"'{existing_row.storage_identifier}', not '{storage_identifier}'.", + conflict_field="storage_identifier" + ) return self._registry_entry_to_model_record(existing_row) metadata = metadata or {} safe_name = self.safe_model_name(model_name) - storage_name = self.storage_name( - safe_model_name=safe_name, + storage_name = storage_identifier or self.storage_name( + safe_model_name=safe_name, index_type=index_type, backend_type=backend_type ) diff --git a/tests/test_cli_pgvector_snapshot.py b/tests/test_cli_pgvector_snapshot.py new file mode 100644 index 0000000..bb3f2e8 --- /dev/null +++ b/tests/test_cli_pgvector_snapshot.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +import json +from pathlib import Path + +import pytest +import sqlalchemy as sa +from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session + +from omop_emb.cli import ( + export_pgvector, + import_pgvector, + migrate_legacy_pgvector_registry, +) +from omop_emb.config import BackendType, IndexType +from omop_emb.interface import EmbeddingInterface + +from .conftest import CONCEPTS, EMBEDDING_DIM, MODEL_NAME + + +@pytest.mark.pgvector +@pytest.mark.unit +def test_pgvector_export_import_roundtrip( + session: Session, + mock_llm_client, + temp_storage_dir: Path, +) -> None: + engine = session.get_bind() + assert isinstance(engine, Engine) + + source_storage = temp_storage_dir / "source_registry" + source_storage.mkdir(parents=True, exist_ok=True) + + interface = EmbeddingInterface.from_backend_name( + backend_name=BackendType.PGVECTOR, + storage_base_dir=str(source_storage), + embedding_client=mock_llm_client, + ) + interface.initialise_store(engine) + + interface.register_model( + 
engine=engine, + model_name=MODEL_NAME, + dimensions=EMBEDDING_DIM, + index_type=IndexType.FLAT, + metadata={"origin": "roundtrip-test"}, + ) + + concept_names = ["Hypertension", "Diabetes"] + concept_ids = tuple(CONCEPTS[name].concept_id for name in concept_names) + embeddings = mock_llm_client.embeddings(concept_names) + + interface.add_to_db( + session=session, + index_type=IndexType.FLAT, + concept_ids=concept_ids, + embeddings=embeddings, + model=MODEL_NAME, + ) + + engine_url = engine.url.render_as_string(hide_password=False) + + snapshot_dir = temp_storage_dir / "snapshot" + with pytest.MonkeyPatch.context() as mp: + mp.setenv("OMOP_DATABASE_URL", engine_url) + export_pgvector( + output_dir=str(snapshot_dir), + storage_base_dir=str(source_storage), + model=[MODEL_NAME], + index_type=IndexType.FLAT, + ) + + manifest_path = snapshot_dir / "manifest.json" + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + assert manifest["backend"] == BackendType.PGVECTOR.value + assert len(manifest["tables"]) == 1 + + storage_identifier = manifest["tables"][0]["storage_identifier"] + table_name = f'"{storage_identifier}"' + + with engine.begin() as conn: + conn.execute(sa.text(f"TRUNCATE TABLE {table_name}")) + + target_storage = temp_storage_dir / "target_registry" + target_storage.mkdir(parents=True, exist_ok=True) + + with pytest.MonkeyPatch.context() as mp: + mp.setenv("OMOP_DATABASE_URL", engine_url) + import_pgvector( + input_dir=str(snapshot_dir), + storage_base_dir=str(target_storage), + replace=True, + batch_size=2, + ) + + with engine.connect() as conn: + count = int(conn.execute(sa.text(f"SELECT COUNT(*) FROM {table_name}")).scalar_one()) + assert count == len(concept_ids) + + +@pytest.mark.pgvector +@pytest.mark.unit +def test_migrate_legacy_pgvector_registry( + session: Session, + temp_storage_dir: Path, +) -> None: + engine = session.get_bind() + assert isinstance(engine, Engine) + + legacy_table = "model_registry_legacy" + + with engine.begin() 
as conn: + conn.execute( + sa.text( + f""" + CREATE TABLE {legacy_table} ( + model_name TEXT NOT NULL, + dimensions INTEGER NOT NULL, + index_type TEXT NOT NULL, + table_name TEXT NOT NULL, + details JSONB + ) + """ + ) + ) + conn.execute( + sa.text( + f""" + INSERT INTO {legacy_table} (model_name, dimensions, index_type, table_name, details) + VALUES (:model_name, :dimensions, :index_type, :table_name, CAST(:details AS JSONB)) + """ + ), + { + "model_name": "legacy-model", + "dimensions": 1, + "index_type": "flat", + "table_name": "legacy_table_flat", + "details": '{"migrated": true}', + }, + ) + + source_url = engine.url.render_as_string(hide_password=False) + storage_dir = temp_storage_dir / "migrated_registry" + storage_dir.mkdir(parents=True, exist_ok=True) + + migrate_legacy_pgvector_registry( + storage_base_dir=str(storage_dir), + source_database_url=source_url, + legacy_table=legacy_table, + dry_run=True, + drop_legacy_registry=False, + ) + + interface = EmbeddingInterface.from_backend_name( + backend_name=BackendType.PGVECTOR, + storage_base_dir=str(storage_dir), + ) + migrated_before = interface.backend.embedding_model_registry.get_registered_models_from_db( + backend_type=BackendType.PGVECTOR, + model_name="legacy-model", + index_type=IndexType.FLAT, + ) + assert migrated_before is None + + migrate_legacy_pgvector_registry( + storage_base_dir=str(storage_dir), + source_database_url=source_url, + legacy_table=legacy_table, + dry_run=False, + drop_legacy_registry=True, + ) + + migrated = interface.backend.embedding_model_registry.get_registered_models_from_db( + backend_type=BackendType.PGVECTOR, + model_name="legacy-model", + index_type=IndexType.FLAT, + ) + assert migrated is not None + assert len(migrated) == 1 + assert migrated[0].storage_identifier == "legacy_table_flat" + + inspector = sa.inspect(engine) + assert not inspector.has_table(legacy_table) From 0ebfa44c3f241934bc16d30cc800394b88323828 Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Thu, 9 
Apr 2026 03:51:22 +0000 Subject: [PATCH 13/19] Little tweak of docs --- docs/index.md | 13 ++++++------- docs/usage/backend-selection.md | 2 +- docs/usage/installation.md | 11 ----------- mkdocs.yml | 2 +- 4 files changed, 8 insertions(+), 20 deletions(-) diff --git a/docs/index.md b/docs/index.md index 01f4196..3c79e15 100644 --- a/docs/index.md +++ b/docs/index.md @@ -33,18 +33,17 @@ At runtime, backend choice should also be explicit. The intended direction is: Recommended runtime environment variables: -- `OMOP_EMB_BACKEND` (`pgvector` or `faiss`) -- `OMOP_EMB_BASE_STORAGE_DIR` (base directory for local metadata and FAISS artifacts; defaults to `./.omop_emb` in the current working directory if unset) -- `OMOP_DATABASE_URL` (OMOP CDM database URL) +- **`OMOP_EMB_BACKEND`**: `pgvector` or `faiss` +- **`OMOP_EMB_BASE_STORAGE_DIR`**: base directory for local metadata and FAISS artifacts; defaults to `omop_emb/.omop_emb` in the root directory of the package. +- **`OMOP_DATABASE_URL`**: OMOP CDM database URL !!! info "Important caveats" - - - `omop-emb` depends on OMOP CDM database access for concept metadata and filtering. - - Current operational and test coverage is PostgreSQL-focused. + - `omop-emb` depends on OMOP CDM database access for concept metadata and filtering. + - Current operational and test coverage is PostgreSQL-focused. Extension planned in the future.
## Documentation overview - [Installation](usage/installation.md) -- [Backend Selection](usage/backend-selection.md) +- [Embedding storage backends](usage/backend-selection.md) - [CLI Reference](usage/cli.md) diff --git a/docs/usage/backend-selection.md b/docs/usage/backend-selection.md index d1f2c6d..9c7754f 100644 --- a/docs/usage/backend-selection.md +++ b/docs/usage/backend-selection.md @@ -1,4 +1,4 @@ -# Backend Selection +# Backend Selection for embeddings `omop-emb` now has a backend abstraction layer so embedding storage and retrieval can be selected explicitly instead of being inferred implicitly from diff --git a/docs/usage/installation.md b/docs/usage/installation.md index 5819d3b..0fa9670 100644 --- a/docs/usage/installation.md +++ b/docs/usage/installation.md @@ -64,14 +64,3 @@ That avoids silent fallback between backend implementations. If it is not set, `omop-emb` defaults to `./.omop_emb` in the current working directory. If a provided path includes `~`, it is expanded automatically. -## Current database support caveat - -PostgreSQL-specific embedding dependencies are now optional, but the broader -system has not yet been tested against non-PostgreSQL database backends. 
- -So the current position is: - -- PostgreSQL embedding infrastructure is optional -- a database backend is still always required -- database backends other than PostgreSQL should currently be treated as - unverified diff --git a/mkdocs.yml b/mkdocs.yml index d982d7b..e87e07e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -31,7 +31,7 @@ nav: - Home: index.md - Getting Started: - Installation: usage/installation.md - - Backend Selection: usage/backend-selection.md + - Embedding Storage: usage/backend-selection.md - "CLI Reference": usage/cli.md plugins: From 411b3ea9f97ede90f0c4369814a4960279b5b813 Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Thu, 9 Apr 2026 05:02:45 +0000 Subject: [PATCH 14/19] Have domains in add_embeddings (should've been there ages ago) --- src/omop_emb/cli.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/omop_emb/cli.py b/src/omop_emb/cli.py index 7baf4f1..e7506d0 100644 --- a/src/omop_emb/cli.py +++ b/src/omop_emb/cli.py @@ -116,6 +116,10 @@ def add_embeddings( "--vocabulary", help="Optional vocabulary filter. Repeat the option to embed concepts only from specific OMOP vocabularies." )] = None, + domains: Annotated[Optional[list[str]], typer.Option( + "--domain", + help="Optional domain filter. Repeat the option to embed concepts only from specific OMOP domains." + )] = None, num_embeddings: Annotated[Optional[int], typer.Option( "--num-embeddings", "-n", help="If set, limits the number of concepts for which embeddings are generated. 
Useful for testing and development to speed up the embedding generation step.")] = None, @@ -140,6 +144,7 @@ def add_embeddings( concept_filter = EmbeddingConceptFilter( require_standard=standard_only, + domains=tuple(domains) if domains else None, vocabularies=tuple(vocabularies) if vocabularies else None, ) From 12b1a207093bc8e19d85e719a42e93534ba308bc Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Fri, 10 Apr 2026 04:36:27 +0000 Subject: [PATCH 15/19] Force absolute path so it can be accurately traced where the metadata is stored --- src/omop_emb/backends/base.py | 17 ++++++++++++++--- tests/test_config.py | 5 +++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/omop_emb/backends/base.py b/src/omop_emb/backends/base.py index 3afd78d..f438175 100644 --- a/src/omop_emb/backends/base.py +++ b/src/omop_emb/backends/base.py @@ -68,7 +68,7 @@ class EmbeddingBackend(ABC, Generic[T]): to validate models, resolve concept metadata, and apply domain/vocabulary filters. """ - DEFAULT_BASE_STORAGE_DIR = ".omop_emb" + DEFAULT_BASE_STORAGE_DIR = Path.home() / ".omop_emb" def __init__( self, storage_base_dir: Optional[str | Path] = None, @@ -77,8 +77,19 @@ def __init__( super().__init__() # Local storage for model registry database and more - storage_base_dir = storage_base_dir or os.getenv(ENV_BASE_STORAGE_DIR) or self.DEFAULT_BASE_STORAGE_DIR - self._storage_base_dir = Path(storage_base_dir).expanduser() + storage_base_dir = ( + storage_base_dir + or os.getenv(ENV_BASE_STORAGE_DIR) + or self.DEFAULT_BASE_STORAGE_DIR + ) + resolved_storage_base_dir = Path(storage_base_dir).expanduser() + if not resolved_storage_base_dir.is_absolute(): + raise ValueError( + "storage_base_dir must be an absolute path. " + f"Got '{storage_base_dir}'." 
+ ) + + self._storage_base_dir = resolved_storage_base_dir self._storage_base_dir.mkdir(parents=True, exist_ok=True) diff --git a/tests/test_config.py b/tests/test_config.py index eadcb91..00c29be 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -19,6 +19,11 @@ def test_get_faiss_backend(self, tmp_path): """Test getting FAISS backend.""" backend = get_embedding_backend("faiss", storage_base_dir=str(tmp_path)) assert backend.backend_type == BackendType.FAISS + + def test_get_backend_rejects_relative_storage_base_dir(self): + """Relative storage directories are rejected to avoid cwd-dependent registry paths.""" + with pytest.raises(ValueError, match="absolute path"): + get_embedding_backend("faiss", storage_base_dir=".operator") def test_faiss_supports_flat_index(self): """Test FAISS supports FLAT index.""" From ed32c77f0da022f7986c854133bd07ee26c0795e Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Fri, 10 Apr 2026 04:37:44 +0000 Subject: [PATCH 16/19] Convenience method to set everything up --- src/omop_emb/interface.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/omop_emb/interface.py b/src/omop_emb/interface.py index bfaf719..08d5ac1 100644 --- a/src/omop_emb/interface.py +++ b/src/omop_emb/interface.py @@ -60,6 +60,23 @@ def from_backend_name( embedding_client=embedding_client, ) + def setup_and_register_model( + self, + engine: Engine, + model_name: str, + dimensions: int, + index_type: IndexType, + metadata: Mapping[str, object] = {}, + ) -> None: + self.initialise_store(engine) + self.register_model( + engine=engine, + model_name=model_name, + dimensions=dimensions, + index_type=index_type, + metadata=metadata, + ) + def initialise_store(self, engine: Engine) -> None: self.backend.initialise_store(engine) From 6713ce6ba2660c6a64b65817868b3518ccd28afa Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Fri, 10 Apr 2026 04:37:51 +0000 Subject: [PATCH 17/19] Fix spelling mistake --- src/omop_emb/interface.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/src/omop_emb/interface.py b/src/omop_emb/interface.py index 08d5ac1..577570f 100644 --- a/src/omop_emb/interface.py +++ b/src/omop_emb/interface.py @@ -179,7 +179,7 @@ def get_nearest_concepts( session=session, model_name=model_name, index_type=index_type, - query_embedding=query_embedding, + query_embeddings=query_embedding, concept_filter=concept_filter, metric_type=metric_type, k=k From 4fa9f4ce3c7c49a3b068b45d992b1dedad7548f6 Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Fri, 10 Apr 2026 04:38:56 +0000 Subject: [PATCH 18/19] Use convenience method --- src/omop_emb/cli.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/omop_emb/cli.py b/src/omop_emb/cli.py index e7506d0..ef4664d 100644 --- a/src/omop_emb/cli.py +++ b/src/omop_emb/cli.py @@ -150,7 +150,12 @@ def add_embeddings( # Ensure OMOP metadata tables exist, then initialize the embedding store. create_db(engine) - interface.initialise_store(engine) + interface.setup_and_register_model( + engine=engine, + model_name=model, + dimensions=embedding_dim, + index_type=index_type, + ) with Session(engine) as reader, Session(engine) as writer: total_concepts = num_embeddings or interface.get_concepts_without_embedding_count( From 860a4556537ff4846153b479642aa8b02957e4e1 Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Fri, 10 Apr 2026 04:39:59 +0000 Subject: [PATCH 19/19] Centralised logging configuration in CLI --- src/omop_emb/cli.py | 40 +++++++++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/src/omop_emb/cli.py b/src/omop_emb/cli.py index ef4664d..2f4c4e9 100644 --- a/src/omop_emb/cli.py +++ b/src/omop_emb/cli.py @@ -3,13 +3,14 @@ import sqlalchemy as sa from sqlalchemy.orm import Session -from orm_loader.helpers import get_logger, configure_logging, create_db +from orm_loader.helpers import get_logger, create_db from typing import Annotated, Optional from pathlib import Path import 
csv import json import os +import logging from dotenv import load_dotenv from tqdm import tqdm import typer @@ -24,6 +25,19 @@ SNAPSHOT_MANIFEST_NAME = "manifest.json" +def configure_logging_level(verbosity: int) -> None: + """Configure global logging based on CLI verbosity flags.""" + level_map = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} + log_level = level_map.get(min(verbosity, 2), logging.DEBUG) + + logging.basicConfig( + level=log_level, + format="%(asctime)s | %(name)s | %(levelname)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + force=True, + ) + + def _load_legacy_rows(engine: sa.Engine, legacy_table: str) -> list[dict[str, object]]: inspector = sa.inspect(engine) if not inspector.has_table(legacy_table): @@ -123,8 +137,12 @@ def add_embeddings( num_embeddings: Annotated[Optional[int], typer.Option( "--num-embeddings", "-n", help="If set, limits the number of concepts for which embeddings are generated. Useful for testing and development to speed up the embedding generation step.")] = None, + verbosity: Annotated[int, typer.Option( + "--verbose", "-v", count=True, + help="Increase verbosity (up to two levels)" + )] = 0, ): - configure_logging() + configure_logging_level(verbosity) load_dotenv() engine = _resolve_engine() @@ -210,9 +228,13 @@ def export_pgvector( "--index-type", help="Optional index-type filter.", )] = None, + verbosity: Annotated[int, typer.Option( + "--verbose", "-v", count=True, + help="Increase verbosity (up to two levels)" + )] = 0, ): """Export pgvector embedding tables to file for checkpoint/restore workflows.""" - configure_logging() + configure_logging_level(verbosity) load_dotenv() engine = _resolve_engine() @@ -300,9 +322,13 @@ def import_pgvector( "--batch-size", "-b", help="Number of rows per INSERT batch.", )] = 5000, + verbosity: Annotated[int, typer.Option( + "--verbose", "-v", count=True, + help="Increase verbosity (up to two levels)" + )] = 0, ): """Import pgvector embedding tables from a previously 
exported snapshot.""" - configure_logging() + configure_logging_level(verbosity) load_dotenv() input_path = Path(input_dir).resolve() @@ -414,9 +440,13 @@ def migrate_legacy_pgvector_registry( "--drop-legacy-registry", help="If set, drop the legacy table after successful migration.", )] = False, + verbosity: Annotated[int, typer.Option( + "--verbose", "-v", count=True, + help="Increase verbosity (up to two levels)" + )] = 0, ): """Migrate legacy pgvector registry rows into local metadata.db registry.""" - configure_logging() + configure_logging_level(verbosity) load_dotenv() source_url = source_database_url or os.getenv("OMOP_DATABASE_URL")