diff --git a/.gitignore b/.gitignore index d50942f..1383310 100644 --- a/.gitignore +++ b/.gitignore @@ -137,6 +137,7 @@ celerybeat.pid # Environments .env .envrc +.omop_emb/ .venv env/ venv/ diff --git a/README.md b/README.md index 949bbcd..f3dd4b3 100644 --- a/README.md +++ b/README.md @@ -1,36 +1,50 @@ # omop-emb Embedding layer for OMOP CDM. +`omop-emb` now separates model metadata from embedding storage: + +- model metadata is stored locally in SQLite (`metadata.db`) +- embedding vectors are stored by the selected backend (`pgvector` or `faiss`) +- OMOP concept metadata remains in the OMOP CDM database + ## Installation `omop-emb` now exposes backend-specific optional dependencies so installation can match the embedding backend you actually intend to use. ```bash -pip install "omop-emb[postgres]" +pip install "omop-emb[pgvector]" pip install "omop-emb[faiss]" pip install "omop-emb[all]" ``` Notes: -- `postgres` installs the PostgreSQL/pgvector dependencies. +- `pgvector` installs the PostgreSQL/pgvector dependencies. - `faiss` installs the FAISS-based backend dependencies. This currently only includes CPU support - `all` installs both backend stacks for development or mixed environments. - A plain `pip install omop-emb` installs the shared core package only. -- PostgreSQL-specific embedding dependencies are now optional, but `omop-emb` - still requires some database backend for OMOP access and model registration. +- PostgreSQL-specific embedding dependencies are optional, but `omop-emb` + still requires OMOP CDM database access. - Non-PostgreSQL database backends have not yet been tested. +## Runtime Configuration + +Common environment variables: + +- `OMOP_EMB_BACKEND`: backend name (`pgvector` or `faiss`) used by the backend factory. +- `OMOP_EMB_BASE_STORAGE_DIR`: local base directory for `omop-emb` artifacts, including local metadata (`metadata.db`) and FAISS files. If unset, `omop-emb` defaults to `./.omop_emb` in the current working directory. 
+- `OMOP_DATABASE_URL`: SQLAlchemy URL for the OMOP CDM database. + Extended documentation can be found [here](https://AustralianCancerDataNetwork.github.io/omop-emb). # Project Roadmap -- [x] Interface for postgres storage of vectors +- [x] Interface for PostgreSQL storage of vectors - [x] Interface for FAISS storage of embeddings -- [ ] Extensive unit testing - - [ ] Backend testing - - [ ] Corruption and restoration of DB testing +- [x] Extensive unit testing + - [x] Backend testing + - [x] Corruption and restoration of DB testing - [ ] Support non-Flat indices for each backend - [ ] `faiss` GPU support - [ ] [`pgvectorscale`](https://github.com/timescale/pgvectorscale) support diff --git a/docs/index.md b/docs/index.md index c62c5af..3c79e15 100644 --- a/docs/index.md +++ b/docs/index.md @@ -5,11 +5,12 @@ The package currently supports: - dynamic embedding model registration - - multiple embedding models can be stored in the respective backend + - model metadata is stored locally in SQLite (`metadata.db`) + - multiple embedding models can be tracked per backend and index type - embedding and lookup for OMOP concepts -- supports various backends with a PostgreSQL linker +- supports various storage backends - [pgvector](https://github.com/pgvector/pgvector): storage in the original OMOP database - - [FAISS](https://github.com/facebookresearch/faiss): efficient storage on disk for low-RAM applications + - [FAISS](https://github.com/facebookresearch/faiss): on-disk vector storage and index files - Extension to [`omop-alchemy`](https://AustralianCancerDataNetwork.github.io/OMOP_Alchemy/) to support new tables - CLI scripts to add embeddings to an already existing OMOP CDM @@ -18,7 +19,7 @@ The package currently supports: Install the backend you actually want to use: ```bash -pip install "omop-emb[postgres]" +pip install "omop-emb[pgvector]" pip install "omop-emb[faiss]" pip install "omop-emb[all]" ``` @@ -28,15 +29,21 @@ A plain `pip install omop-emb` installs only 
the shared core package. At runtime, backend choice should also be explicit. The intended direction is:

 - install-time choice via extras
-- runtime choice via config such as `OMOP_EMB_BACKEND=postgres` or `OMOP_EMB_BACKEND=faiss` or passing it as an argument to the respective interface (e.g. see [CLI reference](usage/cli.md))
+- runtime choice via config such as `OMOP_EMB_BACKEND=pgvector` or `OMOP_EMB_BACKEND=faiss` or passing it as an argument to the respective interface (e.g. see [CLI reference](usage/cli.md))
+
+Recommended runtime environment variables:
+
+- **`OMOP_EMB_BACKEND`**: `pgvector` or `faiss`
+- **`OMOP_EMB_BASE_STORAGE_DIR`**: base directory for local metadata and FAISS artifacts; defaults to `./.omop_emb` in the current working directory.
+- **`OMOP_DATABASE_URL`**: OMOP CDM database URL

-!!! info "Important caveats"
-    - `omop-emb` depends on an OMOP PostgreSQL database for storage of embeddings (pgvector) or to keep track of already embedded concepts.
+!!! info "Important caveats"
+    - `omop-emb` depends on OMOP CDM database access for concept metadata and filtering.
+    - Current operational and test coverage is PostgreSQL-focused; support for other database backends is planned.

 ## Documentation overview

 - [Installation](usage/installation.md)
-- [Backend Selection](usage/backend-selection.md)
+- [Embedding storage backends](usage/backend-selection.md)
- [CLI Reference](usage/cli.md)
diff --git a/docs/usage/backend-selection.md b/docs/usage/backend-selection.md
index 29aa9aa..9c7754f 100644
--- a/docs/usage/backend-selection.md
+++ b/docs/usage/backend-selection.md
@@ -1,4 +1,4 @@
-# Backend Selection
+# Backend Selection for Embeddings

 `omop-emb` now has a backend abstraction layer so embedding storage and
 retrieval can be selected explicitly instead of being inferred implicitly from
whatever happens to be installed. @@ -8,10 +8,11 @@ 
The current backend factory recognizes: -- `pgvector`: The [pgvector](https://github.com/pgvector/pgvector) extension to a standard postgres database to store embeddings directly in the database. +- `pgvector`: The [pgvector](https://github.com/pgvector/pgvector) extension to a standard PostgreSQL database to store embeddings directly in the database. - `faiss`: The [FAISS](https://github.com/facebookresearch/faiss) storage solution for on-disk storage. -The default backend name is currently `postgres`. +There is no implicit default backend name. You must pass one explicitly or set +`OMOP_EMB_BACKEND`. ## Runtime selection @@ -23,12 +24,18 @@ The intended pattern is: Examples: ```bash -export OMOP_EMB_BACKEND=postgres +export OMOP_EMB_BACKEND=pgvector export OMOP_EMB_BACKEND=faiss +export OMOP_EMB_BASE_STORAGE_DIR=$PWD/.omop_emb ``` You can also pass the backend name directly in Python. +Storage directory behavior: + +- If `OMOP_EMB_BASE_STORAGE_DIR` is unset and no explicit path is passed, `omop-emb` defaults to `./.omop_emb` in the current working directory. +- If a path includes `~`, it is expanded (for example `~/.omop_emb`). 
+ ## Python factory The backend factory lives in `omop_emb.backends`: @@ -36,7 +43,7 @@ The backend factory lives in `omop_emb.backends`: ```python from omop_emb.backends import get_embedding_backend -backend = get_embedding_backend("postgres") +backend = get_embedding_backend("pgvector") backend = get_embedding_backend("faiss") ``` @@ -78,10 +85,10 @@ At the moment: - the backend abstraction and backend factory exist - PostgreSQL and FAISS backend classes exist - the production CLI path still targets the PostgreSQL embedding workflow -- PostgreSQL-specific embedding dependencies are optional, but a database - backend is still required for OMOP access and model registration -- model registration is intended to remain shared and database-backed even when - FAISS is used for vector storage and retrieval +- PostgreSQL-specific embedding dependencies are optional, but OMOP database + access is still required for concept metadata +- model registration metadata is stored locally in SQLite (`metadata.db`) under + `OMOP_EMB_BASE_STORAGE_DIR` - database backends other than PostgreSQL have not yet been tested So this page documents the selection model and Python interface shape now, even diff --git a/docs/usage/cli.md b/docs/usage/cli.md index 2697cbc..13fb513 100644 --- a/docs/usage/cli.md +++ b/docs/usage/cli.md @@ -11,14 +11,17 @@ At present, the production CLI path is PostgreSQL-oriented and stores embeddings ## Prerequisites -- **Installation**: install the PostgreSQL backend dependencies: +- **Installation**: install the backend dependencies you plan to use: ```bash - pip install "omop-emb[postgres]" + pip install "omop-emb[pgvector]" + # or + pip install "omop-emb[faiss]" ``` -- **Database**: Postgres implementation of OMOP CDM. See [`omop-graph` documentation](reference-missing) for information how to setup. -- **Environment**: `OMOP_DATABASE_URL` must be exported or existing in the .env file (e.g., `postgresql://user:pass@localhost:5432/omop`). 
+- **Database**: PostgreSQL implementation of OMOP CDM. See the [`omop-graph` documentation](https://AustralianCancerDataNetwork.github.io/omop-graph) for setup instructions.
+- **Environment**: `OMOP_DATABASE_URL` must be exported or present in `.env` (e.g., `postgresql://user:pass@localhost:5432/omop`).
+- **Backend config**: set `OMOP_EMB_BACKEND` (`pgvector` or `faiss`) and optionally `OMOP_EMB_BASE_STORAGE_DIR`.
- **Connectivity**: Access to an OpenAI-compatible embeddings endpoint. *Currently only Ollama supported*.

!!! note "Backend Scope"
@@ -37,8 +40,6 @@ omop-emb add-embeddings --api-base --api-key [OPTIONS]

where `[OPTIONS]` are optional arguments that can be specified as described below.

-### Command Options
-
### Command Options

| Option | Short | Type | Default | Description |
@@ -48,8 +49,121 @@ where `[OPTIONS]` are optional arguments that can be specified as described belo
| **`--index-type`** | | `IndexType` | `FLAT` | The storage index for the embeddings for retrieval. Currently supported: `FLAT`. |
| **`--batch-size`** | `-b` | `Integer` | `100` | Number of concepts to process in each chunk. |
| **`--model`** | `-m` | `String` | `text-embedding-3-small` | Name of the embedding model to use for generating vectors. |
-| **`--backend`** | | `Literal['pgvector', 'faiss']` | `None` | Embedding backend to use (can be replaced by `OMOP_EMB_BACKEND` env var). Requires the respective backend installed using `pip install omop-emb[pgvector or faiss]` |
-| **`--faiss-base-dir`** | | `String` | `None` | Optional base directory for FAISS backend storage. |
+| **`--backend`** | | `Literal['pgvector', 'faiss']` | `None` | Embedding backend to use (can be replaced by `OMOP_EMB_BACKEND`). Requires the corresponding optional dependency. |
+| **`--storage-base-dir`** | | `String` | `None` | Optional base directory for backend storage and local metadata registry (`metadata.db`). 
|
| **`--standard-only`** | | `Boolean` | `False` | If set, only generate embeddings for OMOP standard concepts (`standard_concept = 'S'`). |
| **`--vocabulary`** | | `List[String]` | `None` | Filter to embed concepts only from specific OMOP vocabularies. |
| **`--num-embeddings`** | `-n` | `Integer` | `None` | Limit the number of concepts processed (useful for testing). |
+
+## Environment Variables
+
+- `OMOP_DATABASE_URL`: OMOP CDM database connection string.
+- `OMOP_EMB_BACKEND`: backend selector used when `--backend` is not provided.
+- `OMOP_EMB_BASE_STORAGE_DIR`: local storage root for metadata and file-based artifacts. If unset, `omop-emb` defaults to `./.omop_emb` in the current working directory.
+
+Paths that include `~` are expanded automatically.
+
+---
+
+## `export-pgvector`
+
+Export pgvector embedding tables to CSV files plus a manifest so they can be restored later.
+
+### Usage
+```bash
+omop-emb export-pgvector --output-dir <output-dir> [OPTIONS]
+```
+
+### Options
+
+| Option | Short | Type | Default | Description |
+| :--- | :--- | :--- | :--- | :--- |
+| **`--output-dir`** | `-o` | `String` | **Required** | Directory where snapshot files are written. |
+| **`--storage-base-dir`** | | `String` | `None` | Optional path to local metadata registry (`metadata.db`). If unset, falls back to `OMOP_EMB_BASE_STORAGE_DIR`; if that is also unset, defaults to `./.omop_emb` in the current working directory. Paths with `~` are expanded. |
+| **`--model`** | `-m` | `List[String]` | `None` | Optional model-name filter. Repeat to export specific models only. |
+| **`--index-type`** | | `IndexType` | `None` | Optional index-type filter. |
+
+### Output
+
+The command writes:
+
+- `manifest.json`: snapshot metadata and table mapping
+- One CSV per embedding table, named after the table
+
+---
+
+## `import-pgvector`
+
+Restore pgvector embedding tables from files previously created by `export-pgvector`. 
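Both commands share the snapshot layout described above: a `manifest.json` plus one CSV per embedding table. A generic consumer of that layout can be sketched as follows; the manifest's exact schema is defined by `omop-emb`, so this sketch assumes only the file layout, and `summarize_snapshot` is a hypothetical helper:

```python
import csv
import json
from pathlib import Path

def summarize_snapshot(snapshot_dir: str) -> dict[str, int]:
    """Count data rows per CSV in an export directory. The manifest is
    loaded only to confirm it exists and parses; its schema is backend-defined."""
    root = Path(snapshot_dir)
    json.loads((root / "manifest.json").read_text())  # must exist and be valid JSON
    counts: dict[str, int] = {}
    for csv_path in sorted(root.glob("*.csv")):
        with csv_path.open(newline="") as fh:
            rows = list(csv.reader(fh))
        counts[csv_path.stem] = max(len(rows) - 1, 0)  # subtract the header row
    return counts
```

A check like this is useful before running an import with `--replace`, since truncation is destructive.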
+ +### Usage +```bash +omop-emb import-pgvector --input-dir [OPTIONS] +``` + +### Options + +| Option | Short | Type | Default | Description | +| :--- | :--- | :--- | :--- | :--- | +| **`--input-dir`** | `-i` | `String` | **Required** | Directory containing `manifest.json` and CSV files. | +| **`--storage-base-dir`** | | `String` | `None` | Optional path to local metadata registry (`metadata.db`). If unset, falls back to `OMOP_EMB_BASE_STORAGE_DIR`, otherwise defaults to `./.omop_emb` in the current working directory. Paths with `~` are expanded. | +| **`--replace`** | | `Boolean` | `False` | If set, truncate destination embedding tables before import. | +| **`--batch-size`** | `-b` | `Integer` | `5000` | Number of rows inserted per SQL batch. | + +### Notes + +- Import re-registers pgvector models into local metadata before loading rows. +- Import uses upsert semantics (`ON CONFLICT (concept_id) DO UPDATE`) unless `--replace` is set. + +--- + +## `migrate-legacy-pgvector-registry` + +Migrate legacy pgvector registry rows from a source database table into the local metadata registry (`metadata.db`). + +This command is intended for compatibility with older setups that kept registry metadata in the database instead of the local metadata store. + +### Usage +```bash +omop-emb migrate-legacy-pgvector-registry [OPTIONS] +``` + +### Options + +| Option | Type | Default | Description | +| :--- | :--- | :--- | :--- | +| **`--storage-base-dir`** | `String` | `None` | Optional path to local metadata registry location. If unset, falls back to `OMOP_EMB_BASE_STORAGE_DIR`, otherwise defaults to `./.omop_emb` in the current working directory. | +| **`--source-database-url`** | `String` | `OMOP_DATABASE_URL` | Source database URL containing the legacy registry table. | +| **`--legacy-table`** | `String` | `model_registry` | Name of the legacy registry table in the source database. | +| **`--dry-run`** | `Boolean` | `False` | Show what would be migrated without writing changes. 
| +| **`--drop-legacy-registry`** | `Boolean` | `False` | Drop the legacy table after successful migration. | + +### Recommended Migration Flow + +1. Validate what will migrate: + +```bash +omop-emb migrate-legacy-pgvector-registry --dry-run +``` + +2. Run the migration: + +```bash +omop-emb migrate-legacy-pgvector-registry +``` + +3. Optionally remove legacy table after verification: + +```bash +omop-emb migrate-legacy-pgvector-registry --drop-legacy-registry +``` + +### Field Mapping + +The migration command supports these legacy field names when reading rows: + +- model name: `model_name` +- dimensions: `dimensions` +- index type: `index_type` (fallback: `index_method`) +- storage identifier: `storage_identifier` (fallback: `table_name`) +- metadata: `details` (fallback: `metadata`) diff --git a/docs/usage/installation.md b/docs/usage/installation.md index 80cd1cd..0fa9670 100644 --- a/docs/usage/installation.md +++ b/docs/usage/installation.md @@ -18,7 +18,7 @@ requires database-backed OMOP access and model registration. ## PostgreSQL backend ```bash -pip install "omop-emb[postgres]" +pip install "omop-emb[pgvector]" ``` Use this when you want the current pgvector/PostgreSQL-backed embedding store @@ -33,7 +33,7 @@ pip install "omop-emb[faiss]" Use this when you want the FAISS backend dependencies available. Even in this case, a database backend is still required for OMOP concept -metadata and model registration. +metadata access. ## Everything @@ -52,20 +52,15 @@ install-time. Examples: ```bash -export OMOP_EMB_BACKEND=postgres +export OMOP_EMB_BACKEND=pgvector export OMOP_EMB_BACKEND=faiss +export OMOP_EMB_BASE_STORAGE_DIR=$PWD/.omop_emb ``` That avoids silent fallback between backend implementations. -## Current database support caveat +`OMOP_EMB_BASE_STORAGE_DIR` controls where `omop-emb` stores local metadata +(`metadata.db`) and file-based backend artifacts (such as FAISS files). 
+If it is not set, `omop-emb` defaults to `./.omop_emb` in the current working directory. +If a provided path includes `~`, it is expanded automatically. -PostgreSQL-specific embedding dependencies are now optional, but the broader -system has not yet been tested against non-PostgreSQL database backends. - -So the current position is: - -- PostgreSQL embedding infrastructure is optional -- a database backend is still always required -- database backends other than PostgreSQL should currently be treated as - unverified diff --git a/mkdocs.yml b/mkdocs.yml index 25cd776..e87e07e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -31,6 +31,7 @@ nav: - Home: index.md - Getting Started: - Installation: usage/installation.md + - Embedding Storage: usage/backend-selection.md - "CLI Reference": usage/cli.md plugins: diff --git a/src/omop_emb/__init__.py b/src/omop_emb/__init__.py index 8f791a1..1b28060 100644 --- a/src/omop_emb/__init__.py +++ b/src/omop_emb/__init__.py @@ -1,6 +1,13 @@ from .interface import EmbeddingInterface from .backends.base import EmbeddingConceptFilter -from .backends.config import BackendType, IndexType, MetricType +from .config import ( + BackendType, + IndexType, + MetricType, + parse_backend_type, + parse_index_type, + parse_metric_type, +) from .backends.factory import get_embedding_backend __all__ = [ @@ -9,5 +16,8 @@ "BackendType", "IndexType", "MetricType", + "parse_backend_type", + "parse_index_type", + "parse_metric_type", "get_embedding_backend", ] diff --git a/src/omop_emb/backends/__init__.py b/src/omop_emb/backends/__init__.py index ede168c..4096d1c 100644 --- a/src/omop_emb/backends/__init__.py +++ b/src/omop_emb/backends/__init__.py @@ -1,18 +1,6 @@ from .base import ( EmbeddingBackend, ) -from .embedding_utils import ( - EmbeddingModelRecord, - EmbeddingBackendCapabilities, - EmbeddingConceptFilter, - NearestConceptMatch, -) -from .errors import ( - EmbeddingBackendConfigurationError, - EmbeddingBackendDependencyError, - 
EmbeddingBackendError, - UnknownEmbeddingBackendError, -) from .factory import ( get_embedding_backend, normalize_backend_name, @@ -20,14 +8,6 @@ __all__ = [ "EmbeddingBackend", - "EmbeddingBackendCapabilities", - "EmbeddingConceptFilter", - "EmbeddingModelRecord", - "NearestConceptMatch", - "EmbeddingBackendConfigurationError", - "EmbeddingBackendDependencyError", - "EmbeddingBackendError", - "UnknownEmbeddingBackendError", "get_embedding_backend", "normalize_backend_name", "FaissEmbeddingBackend", diff --git a/src/omop_emb/backends/base.py b/src/omop_emb/backends/base.py index 1765a63..f438175 100644 --- a/src/omop_emb/backends/base.py +++ b/src/omop_emb/backends/base.py @@ -2,8 +2,9 @@ from abc import ABC, abstractmethod from typing import Mapping, Optional, Sequence, Union, Type, TypeVar, Generic, Dict, Any, Callable, Tuple -import re +from pathlib import Path from functools import wraps +import os from numpy import ndarray from sqlalchemy import Engine, select, Integer, ForeignKey, func, Select, text @@ -11,26 +12,29 @@ from omop_alchemy.cdm.model.vocabulary import Concept -from .registry import ModelRegistry, ensure_model_registry_schema -from .config import BackendType, SUPPORTED_INDICES_AND_METRICS_PER_BACKEND, IndexType, MetricType -from .errors import ModelRegistrationConflictError -from .embedding_utils import ( - EmbeddingModelRecord, +from ..model_registry import EmbeddingModelRecord, ModelRegistryManager +from ..config import BackendType, IndexType, MetricType, ENV_BASE_STORAGE_DIR +from omop_emb.utils.embedding_utils import ( EmbeddingConceptFilter, NearestConceptMatch, - EmbeddingBackendCapabilities ) def require_registered_model(func: Callable) -> Callable: @wraps(func) - def wrapper(self, session: Session, model_name: str, *args, **kwargs) -> Any: - record = self.get_registered_model(session, model_name) + def wrapper( + self, + model_name: str, + index_type: IndexType, + *args, **kwargs + ) -> Any: + record = 
self.get_registered_model(model_name=model_name, index_type=index_type)
        if record is None:
            raise ValueError(
-                f"Embedding model '{model_name}' is not registered in the FAISS backend."
+                f"Embedding model '{model_name}' with index_type='{index_type.value}' is not registered in backend '{self.backend_type.value}'."
            )
-        return func(self, session, model_name, record, *args, **kwargs)
+        kwargs["_model_record"] = record
+        return func(self, model_name, index_type, *args, **kwargs)
    return wrapper
@@ -54,12 +58,7 @@ class EmbeddingBackend(ABC, Generic[T]):
     ------------
     - Keep embedding generation separate from embedding persistence.
     - Support multiple storage/retrieval implementations behind one contract.
-    - Preserve the current operational needs of ``omop-graph``:
-      - model registration
-      - embedding population
-      - embedding lookup by concept ID
-      - similarity lookup over a set of concept IDs
-      - nearest-neighbor retrieval with OMOP-oriented filters
+    - Local storage of model registry metadata, but allow flexible storage of the actual embeddings (e.g., relational tables, FAISS indices, files in object storage, etc.)

     Notes
     -----
@@ -69,14 +68,51 @@ class EmbeddingBackend(ABC, Generic[T]):
     to validate models, resolve concept metadata, and apply domain/vocabulary
     filters.
     """
-    def __init__(self):
+    DEFAULT_BASE_STORAGE_DIR = Path.cwd() / ".omop_emb"  # matches the documented ./.omop_emb default
+    def __init__(
+        self,
+        storage_base_dir: Optional[str | Path] = None,
+        registry_db_name: Optional[str] = None,
+    ):
         super().__init__()
-        self._model_cache: dict[str, Type[T]] = {}
+
+        # Local storage for model registry database and more
+        storage_base_dir = (
+            storage_base_dir
+            or os.getenv(ENV_BASE_STORAGE_DIR)
+            or self.DEFAULT_BASE_STORAGE_DIR
+        )
+        resolved_storage_base_dir = Path(storage_base_dir).expanduser()
+        if not resolved_storage_base_dir.is_absolute():
+            raise ValueError(
+                "storage_base_dir must be an absolute path. "
+                f"Got '{storage_base_dir}'." 
+ ) + + self._storage_base_dir = resolved_storage_base_dir + self._storage_base_dir.mkdir(parents=True, exist_ok=True) + + + self._embedding_table_cache: dict[Tuple[str, BackendType, IndexType], Type[T]] = {} + self._embedding_model_registry = ModelRegistryManager( + base_dir=self.storage_base_dir, + db_file=registry_db_name, + ) @property - def model_cache(self) -> Dict[str, Type[T]]: + def storage_base_dir(self) -> Path: + """Base directory for any local storage needs of the backend, such as model registry metadata or file-based embedding storage.""" + return self._storage_base_dir + + @property + def embedding_table_cache(self) -> Dict[Tuple[str, BackendType, IndexType], Type[T]]: """In-memory cache of dynamically generated ORM classes for embedding tables.""" - return self._model_cache + return self._embedding_table_cache + + @property + def embedding_model_registry(self) -> ModelRegistryManager: + """Manager for embedding model metadata and registry operations.""" + return self._embedding_model_registry @property @abstractmethod @@ -87,87 +123,112 @@ def backend_type(self) -> BackendType: def backend_name(self) -> str: """Stable identifier for this backend implementation.""" return self.backend_type.value + + def _cache_model_record( + self, + engine: Engine, + model_record: EmbeddingModelRecord, + ) -> None: + """Cache the given model record in the embedding table cache. 
This is used to keep track of which models are registered and their associated metadata.""" + dynamic_table = self._create_storage_table(engine=engine, model_record=model_record) + storage_key = (model_record.model_name, self.backend_type, model_record.index_type) + self.embedding_table_cache[storage_key] = dynamic_table - @property + @abstractmethod - def capabilities(self) -> EmbeddingBackendCapabilities: - """Capability flags describing what this backend can do.""" + def _create_storage_table(self, engine: Engine, model_record: EmbeddingModelRecord) -> Type[T]: + """ + Backend-specific logic to create the dynamic SQLAlchemy class. + + Parameters + ---------- + engine : Engine + SQLAlchemy engine for OMOP CDM database storage. + model_record : EmbeddingModelRecord + Registered model metadata used to build backend storage. + """ def initialise_store(self, engine: Engine) -> None: """ - Prepare any required storage structures. + Initialise the model registry and populate the embedding table cache for existing models. + Can extend it to include any backend-specific setup steps (e.g., creating extensions, staging directories) as needed. - Examples: - - create registry tables - - warm caches from a registry - - create directories or sidecar files + Parameters + ---------- + engine : Engine + SQLAlchemy engine for OMOP CDM database storage. 
""" + self.pre_initialise_store(engine) - with Session(engine, expire_on_commit=False) as session: - existing_models = session.scalars(select(ModelRegistry).where(ModelRegistry.backend_type == self.backend_type)).all() - #for model_entry in existing_models: - # _heal_legacy_index_method(session, model_entry) - session.commit() - - for model_entry in existing_models: - if model_entry.model_name not in self.model_cache: - # Create the class and cache it - dynamic_table = self._create_storage_table(engine=engine, entry=model_entry) - self.model_cache[model_entry.model_name] = dynamic_table - - def list_registered_models(self, session: Session) -> Sequence[EmbeddingModelRecord]: - """Return all embedding models known to this backend.""" - return tuple( - self._record_from_registry_row(row) - for row in session.scalars( - select(ModelRegistry).where(ModelRegistry.backend_type == self.backend_type) - ).all() + registered_models = self.embedding_model_registry.get_registered_models_from_db( + backend_type=self.backend_type ) + if registered_models is None: + return + + for model_record in registered_models: + self._cache_model_record(engine=engine, model_record=model_record) + + def pre_initialise_store(self, engine: Engine) -> None: + """ + Hook for setup steps that run before store initialization. + + Parameters + ---------- + engine : Engine + SQLAlchemy engine for OMOP CDM database storage. 
+ """ + pass def get_registered_model( self, - session: Session, model_name: str, + index_type: IndexType, ) -> Optional[EmbeddingModelRecord]: """Return metadata for one registered model, if present.""" - row = session.scalar( - select(ModelRegistry).where( - ModelRegistry.model_name == model_name, - ModelRegistry.backend_type == self.backend_type, - ) + registered_model = self.embedding_model_registry.get_registered_models_from_db( + backend_type=self.backend_type, + model_name=model_name, + index_type=index_type ) - if row is None: + if registered_model is None: return None - return self._record_from_registry_row(row) + + assert len(registered_model) == 1, ( + f"Expected exactly one registered model for name '{model_name}' and index type '{index_type.value}', but found {len(registered_model)}. This indicates a data integrity issue in the model registry database." + ) + return registered_model[0] + - def is_model_registered(self, session: Session, model_name: str) -> bool: + def is_model_registered(self, model_name: str, index_type: IndexType) -> bool: """Convenience wrapper over ``get_registered_model``.""" - return self.get_registered_model(session=session, model_name=model_name) is not None + return self.get_registered_model(model_name=model_name, index_type=index_type) is not None def get_concepts_without_embedding( self, session: Session, model_name: str, + index_type: IndexType, concept_filter: Optional[EmbeddingConceptFilter] = None, limit: Optional[int] = None, ) -> Mapping[int, str]: """Return concept IDs and names for concepts that do not have embeddings.""" - query = self.get_concepts_without_embedding_query( - session=session, + query = self.q_get_concepts_without_embedding( model_name=model_name, + index_type=index_type, concept_filter=concept_filter, limit=limit ) return {row.concept_id: row.concept_name for row in session.execute(query)} - def get_concepts_without_embedding_query( + def q_get_concepts_without_embedding( self, - session: Session, 
model_name: str, + index_type: IndexType, concept_filter: Optional[EmbeddingConceptFilter] = None, limit: Optional[int] = None, ) -> Select: - embedding_table = self._get_embedding_table(session=session, model_name=model_name) + embedding_table = self.get_embedding_table(model_name=model_name, index_type=index_type) subq = select(1).where(embedding_table.concept_id == Concept.concept_id) query = ( @@ -184,9 +245,10 @@ def get_concepts_without_embedding_count( self, session: Session, model_name: str, + index_type: IndexType, concept_filter: Optional[EmbeddingConceptFilter] = None, ) -> int: - embedding_table = self._get_embedding_table(session=session, model_name=model_name) + embedding_table = self.get_embedding_table(model_name=model_name, index_type=index_type) subq = select(1).where(embedding_table.concept_id == Concept.concept_id) query = ( @@ -200,9 +262,6 @@ def get_concepts_without_embedding_count( return session.scalar(query) # type: ignore - @abstractmethod - def _create_storage_table(self, engine: Engine, entry: ModelRegistry) -> Type[T]: - """Backend-specific logic to create the dynamic SQLAlchemy class.""" def register_model( self, @@ -211,76 +270,46 @@ def register_model( dimensions: int, *, index_type: IndexType, - metadata: Mapping[str, object] = {}, + metadata: Optional[Mapping[str, object]] = None, ) -> EmbeddingModelRecord: """ Shared template method for model registration. + + Parameters + ---------- + engine : Engine + SQLAlchemy engine for OMOP CDM database storage. + model_name : str + Registered name of the embedding model. + dimensions : int + Embedding dimensionality $D$ for this model. + index_type : IndexType + Storage index type used for this model's embeddings. + metadata : Optional[Mapping[str, object]] + Optional metadata persisted with the model registration. 
""" - self.initialise_store(engine) - with Session(engine, expire_on_commit=False) as session: - existing_row = session.scalar( - select(ModelRegistry).where(ModelRegistry.model_name == model_name) - ) - if existing_row is not None: - if existing_row.backend_type != self.backend_type: - raise ModelRegistrationConflictError( - f"Model '{model_name}' is already registered for backend " - f"'{existing_row.backend_type}', not '{self.backend_type}'.", - conflict_field="backend_type" - ) - if existing_row.dimensions != dimensions: - raise ModelRegistrationConflictError( - f"Model '{model_name}' is already registered with dimensions " - f"{existing_row.dimensions}, not {dimensions}.", - conflict_field="dimensions" - ) - if existing_row.index_type != index_type: - raise ModelRegistrationConflictError( - f"Model '{model_name}' is already registered with " - f"index_method='{existing_row.index_type}', not " - f"'{index_type}'. Reuse the existing model " - "configuration or register a new model name.", - conflict_field="index_type" - ) - if existing_row.details != metadata: - raise ModelRegistrationConflictError( - f"Model '{model_name}' is already registered with different " - f"metadata. 
Reuse the existing model name or choose a new one.", - conflict_field="metadata" - ) - return self._record_from_registry_row(existing_row) - - ensure_model_registry_schema(engine) - safe_name = self.safe_model_name(model_name) - storage_name = self.storage_name(safe_model_name=safe_name) - - new_entry = ModelRegistry( + model_record = self.embedding_model_registry.register_model( model_name=model_name, dimensions=dimensions, - storage_identifier=storage_name, - index_type=index_type, backend_type=self.backend_type, - details=metadata + index_type=index_type, + metadata=metadata or {}, ) - - with Session(engine, expire_on_commit=False) as session: - # Add logic here to check for existing records before adding - session.add(new_entry) - session.commit() - - dynamic_table = self._create_storage_table(engine, new_entry) - self.model_cache[model_name] = dynamic_table - return self._record_from_registry_row(new_entry) + # Local caching + self._cache_model_record(engine=engine, model_record=model_record) + return model_record @abstractmethod @require_registered_model def upsert_embeddings( self, - session: Session, model_name: str, - model_record: EmbeddingModelRecord, + index_type: IndexType, + *, + session: Session, concept_ids: Sequence[int], embeddings: ndarray, + _model_record: EmbeddingModelRecord, ) -> None: """ Insert or update vector embeddings for a collection of OMOP concept IDs. @@ -292,21 +321,18 @@ def upsert_embeddings( Parameters ---------- - session : sqlalchemy.orm.Session - The active database session used for transactional persistence and - model metadata updates. model_name : str - The unique identifier or name of the embedding model (e.g., - 'text-embedding-3-small'). - model_record : EmbeddingModelRecord - A record object containing metadata, dimensions, and configuration - specific to the embedding model being processed. + Registered name of the embedding model. + index_type : IndexType + Storage index type used for this model's embeddings. 
+        session : sqlalchemy.orm.Session
+            SQLAlchemy session bound to the OMOP CDM database.
         concept_ids : Sequence[int]
-            A sequence of OMOP standard concept IDs corresponding to the
-            ordered rows in the embeddings array.
+            Concept IDs aligned with the rows of ``embeddings``.
         embeddings : numpy.ndarray
-            A 2D array of shape (n_concepts, n_dimensions) containing the
-            generated vector representations.
+            Embedding matrix of shape ``(n_concepts, D)``.
+        _model_record : EmbeddingModelRecord
+            Internal registered-model record injected by ``@require_registered_model``.

         Returns
         -------
@@ -323,10 +349,12 @@
     @require_registered_model
     def get_embeddings_by_concept_ids(
         self,
-        session: Session,
         model_name: str,
-        model_record: EmbeddingModelRecord,
+        index_type: IndexType,
+        *,
+        session: Session,
         concept_ids: Sequence[int],
+        _model_record: EmbeddingModelRecord,
     ) -> Mapping[int, Sequence[float]]:
         """Fetch stored embedding vectors keyed by concept ID."""

@@ -335,101 +363,85 @@
     @require_registered_model
     def get_nearest_concepts(
         self,
-        session: Session,
         model_name: str,
+        index_type: IndexType,
+        *,
+        session: Session,
         query_embedding: ndarray,
         metric_type: MetricType,
-        *,
         concept_filter: Optional[EmbeddingConceptFilter] = None,
         k: int = 10,
+        _model_record: EmbeddingModelRecord,
     ) -> Tuple[Tuple[NearestConceptMatch, ...], ...]:
         """
         Return nearest stored concepts for the query embedding.

         Parameters
         ----------
-        session : Session
-            SQLAlchemy session for any required relational access.
         model_name : str
-            Name of the embedding model used to create the embeddings.
+            Registered name of the embedding model.
+        index_type : IndexType
+            Storage index type used for this model's embeddings.
+        session : Session
+            SQLAlchemy session bound to the OMOP CDM database.
         query_embedding : ndarray
-            The embedding vector to search with. Expected shape is (q, dimension)
-            where q is the number of query vectors and dimension is the size of the embedding space for the model.
+            Query embedding matrix of shape ``(Q, D)``.
         metric_type : MetricType
-            The similarity or distance metric to use for nearest neighbor search. This should be compatible with the index type used by the model.
+            Similarity or distance metric for nearest-neighbor search.
         concept_filter : Optional[EmbeddingConceptFilter], optional
-            Optional constraints to apply when retrieving nearest concepts. This allows filtering by concept IDs, domains, vocabularies, or standard flags. By default, no additional filtering is applied.
+            Optional filter restricting which OMOP concepts are considered.
         k : int, optional
-            K nearest neighbors to return for each query vector. Default is 10.
+            Number of nearest matches to return per query.
+        _model_record : EmbeddingModelRecord
+            Internal registered-model record injected by ``@require_registered_model``.

         Returns
         -------
         Tuple[Tuple[NearestConceptMatch, ...], ...]
             A tuple of tuples containing nearest concept matches for each query
             vector. The outer tuple corresponds to the query vectors in order,
             and each inner tuple contains the nearest matches for that query
             vector, sorted by similarity. Returned shape is (q, k) where q is
             the number of query vectors and k is the number of nearest
             neighbors returned per query.
         """
-
-    def refresh_model_index(self, session: Session, model_name: str) -> None:
-        """
-        Optional hook for derived index maintenance.
-
-        PostgreSQL/pgvector backends may not need this. File-based or FAISS
-        backends may choose to rebuild or compact a search index here.
-        """
-        return None
-
-    def has_any_embeddings(self, session: Session, model_name: str) -> bool:
-        embedding_table = self._get_embedding_table(
-            session=session,
+    def has_any_embeddings(
+        self,
+        session: Session,
+        model_name: str,
+        index_type: IndexType,
+    ) -> bool:
+        embedding_table = self.get_embedding_table(
             model_name=model_name,
+            index_type=index_type,
         )
         return session.query(embedding_table.concept_id).limit(1).first() is not None

-    def _get_embedding_table(
+    def get_embedding_table(
         self,
-        session: Session,
         model_name: str,
+        index_type: IndexType,
     ) -> Type[T]:
-        embedding_table = self.model_cache.get(model_name)
+        """Return the dynamically created ORM class for the embedding table associated with the given model name.
+
+        This relies on `initialise_store` having been called to populate the
+        `embedding_table_cache`. If the requested model is not found in the
+        cache, this indicates a logic error in the calling code (e.g., not
+        initialising the store, or requesting a table for a model that hasn't
+        been registered) rather than a user input error, so a ValueError is raised.
+
+        Parameters
+        ----------
+        model_name : str
+            The name of the embedding model whose associated embedding table class is to be retrieved.
+        index_type : IndexType
+            Storage index type for the model; part of the cache key together with the backend type.
+
+        Returns
+        -------
+        Type[T]
+            The dynamically generated SQLAlchemy ORM class corresponding to the embedding table for the specified model.
+
+        Raises
+        ------
+        ValueError
+            If the model name is not found in the embedding table cache, indicating that the store was not properly initialized or the model has not been registered.
+        """
+        storage_key = (model_name, self.backend_type, index_type)
+        embedding_table = self.embedding_table_cache.get(storage_key)
         if embedding_table is not None:
             return embedding_table
-
-        bind = session.get_bind()
-        if bind is None:
-            raise RuntimeError("Session is not bound to an engine.")
-        assert isinstance(bind, Engine), f"Expected session bind to be an Engine. 
Got {type(bind)}"
-        self.initialise_store(bind)  # Ensure tables are created and cache is populated
-        embedding_table = self.model_cache.get(model_name)
-        if embedding_table is None:
-            raise ValueError(f"Embedding model '{model_name}' not found in cache.")
-        return embedding_table
-
-    @staticmethod
-    def _coerce_registry_metadata(value: object) -> Mapping[str, object]:
-        if isinstance(value, Mapping):
-            return dict(value)
-        return {}
-
-    @staticmethod
-    def _record_from_registry_row(
-        row: ModelRegistry,
-    ) -> EmbeddingModelRecord:
-        return EmbeddingModelRecord(
-            model_name=row.model_name,
-            dimensions=row.dimensions,
-            backend_type=row.backend_type,
-            storage_identifier=row.storage_identifier,
-            index_type=row.index_type,
-            metadata=row.details,
-        )
-
-    @staticmethod
-    def safe_model_name(model_name: str) -> str:
-        name = model_name.lower()
-        sanitized = re.sub(r"[^\w]+", "_", name)
-        sanitized = re.sub(r"_+", "_", sanitized).strip("_")
-        return sanitized
-
-    def storage_name(self, safe_model_name: str) -> str:
-        return f"{self.backend_name.lower()}_{safe_model_name}"
+        else:
+            raise ValueError(
+                f"Embedding table for model '{model_name}' with index type "
+                f"'{index_type.value}' and backend '{self.backend_type.value}' "
+                "not found in cache. Ensure that the store is initialized and "
+                "the model is registered before attempting to access the "
+                "embedding table."
+            )

     @staticmethod
     def validate_embeddings(embeddings: ndarray, dimensions: int):
diff --git a/src/omop_emb/backends/factory.py b/src/omop_emb/backends/factory.py
index 37169bc..577e7aa 100644
--- a/src/omop_emb/backends/factory.py
+++ b/src/omop_emb/backends/factory.py
@@ -4,17 +4,18 @@
 from typing import Optional

 from .base import EmbeddingBackend
-from .config import (
+from ..config import (
     BackendType,
     ENV_OMOP_EMB_BACKEND,
+    parse_backend_type,
 )
-from .errors import (
+from omop_emb.utils.errors import (
     EmbeddingBackendConfigurationError,
     EmbeddingBackendDependencyError,
     UnknownEmbeddingBackendError,
 )

-def normalize_backend_name(backend_name: Optional[str]) -> BackendType:
+def normalize_backend_name(backend_name: Optional[str | BackendType]) -> BackendType:
     """
     Normalize an embedding backend name from an explicit argument or env var.

@@ -28,7 +29,7 @@ def normalize_backend_name(backend_name: Optional[str]) -> BackendType:
         raise AttributeError(f"No embedding backend specified. Provide an explicit backend_name or set the {ENV_OMOP_EMB_BACKEND} environment variable.")
     else:
         try:
-            backend_type = BackendType(backend_name)
+            backend_type = parse_backend_type(backend_name)
         except ValueError:
             raise UnknownEmbeddingBackendError(
                 f"Unknown embedding backend {backend_name!r}. "
@@ -38,9 +39,9 @@

 def get_embedding_backend(
-    backend_name: Optional[str] = None,
-    *,
-    faiss_base_dir: Optional[str] = None,
+    backend_name: Optional[str | BackendType] = None,
+    storage_base_dir: Optional[str] = None,
+    registry_db_name: Optional[str] = None,
 ) -> EmbeddingBackend:
     """
     Construct an embedding backend implementation by name.

@@ -59,7 +60,7 @@
                 "PGVector embedding backend requested but its dependencies are not "
                 "available. 
Install the package using `pip install omop-emb[pgvector]`." ) from exc - return PGVectorEmbeddingBackend() + return PGVectorEmbeddingBackend(storage_base_dir=storage_base_dir, registry_db_name=registry_db_name) if resolved == BackendType.FAISS: try: @@ -69,7 +70,7 @@ def get_embedding_backend( "FAISS embedding backend requested but its dependencies are not " "available. Install the package with the FAISS extra using `pip install omop-emb[faiss]` or omop-emb[faiss-gpu]." ) from exc - return FaissEmbeddingBackend(base_dir=faiss_base_dir) + return FaissEmbeddingBackend(storage_base_dir=storage_base_dir, registry_db_name=registry_db_name) raise EmbeddingBackendConfigurationError( f"Backend factory reached an unexpected state for backend={resolved!r}." diff --git a/src/omop_emb/backends/faiss/faiss_backend.py b/src/omop_emb/backends/faiss/faiss_backend.py index decf1dc..e197255 100644 --- a/src/omop_emb/backends/faiss/faiss_backend.py +++ b/src/omop_emb/backends/faiss/faiss_backend.py @@ -1,73 +1,32 @@ from __future__ import annotations -import os from pathlib import Path -from typing import Any, Mapping, Optional, Sequence, Type, cast, Dict, Tuple +from typing import Mapping, Optional, Sequence, Type, Dict, Tuple import numpy as np from numpy import ndarray -from omop_emb.backends.config import BackendType -from sqlalchemy import Engine, insert, select +from sqlalchemy import Engine from sqlalchemy.orm import Session import logging -from dataclasses import dataclass, field -import h5py -from omop_alchemy.cdm.model.vocabulary import Concept -from omop_emb.backends.registry import ModelRegistry - -from ..errors import EmbeddingBackendConfigurationError from .faiss_sql import ( FAISSConceptIDEmbeddingRegistry, create_faiss_embedding_registry_table, add_concept_ids_to_faiss_registry, q_concept_ids_with_embeddings ) -from ..config import IndexType, MetricType, ENV_OMOP_EMB_FAISS_INDEX_DIR +from omop_emb.config import BackendType, IndexType, MetricType from .storage_manager 
import EmbeddingStorageManager from ..base import EmbeddingBackend, require_registered_model -from ..embedding_utils import ( - EmbeddingBackendCapabilities, +from omop_emb.utils.embedding_utils import ( EmbeddingConceptFilter, - EmbeddingModelRecord, NearestConceptMatch, get_similarity_from_distance ) +from omop_emb.model_registry import EmbeddingModelRecord logger = logging.getLogger(__name__) -@dataclass -class FaissMetadata: - """Strongly typed metadata required for the FAISS backend to operate.""" - index_file_path: str - metadata_bucket: dict[str, Any] = field(default_factory=dict) - - def to_dict(self) -> dict[str, Any]: - """Serializes to a flat dictionary for SQLAlchemy JSON storage.""" - # We flatten it so it looks clean in the database - result = self.metadata_bucket.copy() - result["index_file_path"] = self.index_file_path - return result - - @classmethod - def from_dict(cls, data: Mapping[str, Any]) -> "FaissMetadata": - """Safely reconstructs the object from the database JSON.""" - data_copy = dict(data) - - # Pop the required FAISS fields, raising a clear error if missing - try: - index_path = data_copy.pop("index_file_path") - except KeyError as e: - raise ValueError( - f"Corrupted FAISS metadata in registry. 
Missing required key: {e}" - ) - - # Everything else goes back into user_metadata - return cls( - index_file_path=str(index_path), - metadata_bucket=data_copy - ) - class FaissEmbeddingBackend(EmbeddingBackend[FAISSConceptIDEmbeddingRegistry]): """ @@ -91,69 +50,47 @@ class FaissEmbeddingBackend(EmbeddingBackend[FAISSConceptIDEmbeddingRegistry]): DEFAULT_FAISS_DIR = ".omop_emb/faiss" - def __init__(self, base_dir: Optional[str | os.PathLike[str]] = None): - super().__init__() - self.base_dir = Path( - base_dir - or os.getenv(ENV_OMOP_EMB_FAISS_INDEX_DIR) - or FaissEmbeddingBackend.DEFAULT_FAISS_DIR - ).expanduser() + def __init__( + self, + storage_base_dir: Optional[str | Path] = None, + registry_db_name: Optional[str] = None, + ): + super().__init__( + storage_base_dir=storage_base_dir, + registry_db_name=registry_db_name, + ) self.embedding_storage_managers: Dict[str, EmbeddingStorageManager] = {} @property def backend_type(self) -> BackendType: return BackendType.FAISS - @property - def capabilities(self) -> EmbeddingBackendCapabilities: - return EmbeddingBackendCapabilities( - stores_embeddings=True, - supports_incremental_upsert=True, - supports_nearest_neighbor_search=True, - supports_server_side_similarity=False, - supports_filter_by_concept_ids=True, - supports_filter_by_domain=True, - supports_filter_by_vocabulary=True, - supports_filter_by_standard_flag=True, - requires_explicit_index_refresh=False, - ) - - def initialise_store(self, engine) -> None: - self.base_dir.mkdir(parents=True, exist_ok=True) - return super().initialise_store(engine) - - def _create_storage_table(self, engine: Engine, entry: ModelRegistry) -> Type[FAISSConceptIDEmbeddingRegistry]: - return create_faiss_embedding_registry_table(engine=engine, model_registry_entry=entry) + def _create_storage_table(self, engine: Engine, model_record: EmbeddingModelRecord) -> Type[FAISSConceptIDEmbeddingRegistry]: + return create_faiss_embedding_registry_table(engine=engine, 
model_record=model_record) def get_safe_model_dir(self, model_name: str) -> Path: - return self.base_dir / self.safe_model_name(model_name) + return self.storage_base_dir / self.embedding_model_registry.safe_model_name(model_name) def get_storage_manager( self, - model_name: str, - dimensions: int, - index_type: IndexType, + model_record: EmbeddingModelRecord, ) -> EmbeddingStorageManager: self.register_storage_manager( - model_name=model_name, - dimensions=dimensions, - index_type=index_type, + model_record=model_record, ) - return self.embedding_storage_managers[model_name] + return self.embedding_storage_managers[model_record.model_name] def register_storage_manager( self, - model_name: str, - dimensions: int, - index_type: IndexType, + model_record: EmbeddingModelRecord, ) -> None: """Registers a storage manager for the given model if not already registered. This ensures that the necessary on-disk structures are in place for the model's embeddings and index.""" - if model_name not in self.embedding_storage_managers: - logger.info(f"Registering new storage manager for model '{model_name}' with dimensions={dimensions}, index_type={index_type}") - self.embedding_storage_managers[model_name] = EmbeddingStorageManager( - file_dir=self.get_safe_model_dir(model_name), - dimensions=dimensions, + if model_record.model_name not in self.embedding_storage_managers: + logger.info(f"Registering new storage manager for model '{model_record.model_name}' with dimensions={model_record.dimensions}, index_type={model_record.index_type}") + self.embedding_storage_managers[model_record.model_name] = EmbeddingStorageManager( + file_dir=self.get_safe_model_dir(model_record.model_name), + dimensions=model_record.dimensions, backend_type=self.backend_type, ) @@ -168,33 +105,28 @@ def register_model( ) -> EmbeddingModelRecord: # Create the storage for the model on disk - safe_name = self.safe_model_name(model_name) - model_dir = self.base_dir / safe_name + safe_name = 
self.embedding_model_registry.safe_model_name(model_name) + model_dir = self.storage_base_dir / safe_name model_dir.mkdir(parents=False, exist_ok=True) logger.info(f"Created model directory: {model_dir}") - # NOTE: not sure if still necessary as the manager takes care of it all. - #faiss_metadata = FaissMetadata( - # index_file_path=str(self._index_path(model_name)), - # metadata_bucket=dict(metadata), - #) return super().register_model( engine=engine, model_name=model_name, dimensions=dimensions, index_type=index_type, - #metadata=faiss_metadata.to_dict(), metadata=metadata ) @require_registered_model def upsert_embeddings( self, - session: Session, model_name: str, - model_record: EmbeddingModelRecord, + index_type: IndexType, + session: Session, concept_ids: Sequence[int], embeddings: ndarray, + _model_record: EmbeddingModelRecord, metric_type: Optional[MetricType] = None, ) -> None: """ @@ -206,26 +138,20 @@ def upsert_embeddings( Parameters ---------- - session : sqlalchemy.orm.Session - The active database session used for transactional persistence and - model metadata updates. model_name : str - The unique identifier or name of the embedding model (e.g., - 'text-embedding-3-small'). - model_record : EmbeddingModelRecord - A record object containing metadata, dimensions, and configuration - specific to the embedding model being processed. + Registered name of the embedding model. + index_type : IndexType + Storage index type used for this model's embeddings. + session : sqlalchemy.orm.Session + SQLAlchemy session bound to the OMOP CDM database. concept_ids : Sequence[int] - A sequence of OMOP standard concept IDs corresponding to the - ordered rows in the embeddings array. + Concept IDs aligned with the rows of ``embeddings``. embeddings : numpy.ndarray - A 2D array of shape (n_concepts, n_dimensions) containing the - generated vector representations. + Embedding matrix of shape ``(n_concepts, D)``. 
+        _model_record : EmbeddingModelRecord
+            Internal registered-model record injected by ``@require_registered_model``.
         metric_type : Optional[MetricType]
-            The similarity metric type (e.g., COSINE, EUCLIDEAN) that should be
-            associated with the stored embeddings for accurate nearest neighbor
-            search behavior. If None, the index will not be built and the raw
-            embeddings are only stored in the HDF5 file for later use.
+            Optional metric used when creating/updating the FAISS index. If
+            None, the index is not built and the raw embeddings are only
+            stored in the HDF5 file for later use.

         Returns
         -------
@@ -242,7 +168,7 @@
         self.validate_embeddings_and_concept_ids(
             concept_ids=concept_id_tuple,
             embeddings=embeddings,
-            dimensions=model_record.dimensions,
+            dimensions=_model_record.dimensions,
         )

         # Try to add first before we damage the on-disk stuff
@@ -250,7 +176,10 @@
             add_concept_ids_to_faiss_registry(
                 concept_ids=concept_id_tuple,
                 session=session,
-                registered_table=self._get_embedding_table(session, model_name),
+                registered_table=self.get_embedding_table(
+                    model_name=model_name,
+                    index_type=index_type,
+                ),
             )
         except Exception as e:
             session.rollback()
@@ -258,43 +187,36 @@
                 f"Failed to add concept IDs to FAISS registry for model '{model_name}'. This may be due to a mismatch between the provided concept_ids and the existing entries in the database, or a database constraint violation. 
Original error: {str(e)}" ) from e - storage_manager = self.get_storage_manager( - model_name=model_name, - dimensions=model_record.dimensions, - index_type=model_record.index_type, - ) + storage_manager = self.get_storage_manager(_model_record) storage_manager.append( concept_ids=np.array(concept_id_tuple, dtype=np.int64), embeddings=embeddings, - index_type=model_record.index_type, + index_type=_model_record.index_type, metric_type=metric_type ) @require_registered_model def get_nearest_concepts( self, - session: Session, model_name: str, - model_record: EmbeddingModelRecord, + index_type: IndexType, + session: Session, query_embeddings: np.ndarray, metric_type: MetricType, *, concept_filter: Optional[EmbeddingConceptFilter] = None, k: int = 10, + _model_record: EmbeddingModelRecord, ) -> Tuple[Tuple[NearestConceptMatch, ...], ...]: - storage_manager = self.get_storage_manager( - model_name=model_name, - dimensions=model_record.dimensions, - index_type=model_record.index_type, - ) + storage_manager = self.get_storage_manager(_model_record) if storage_manager.get_count() == 0: return () - self.validate_embeddings(embeddings=query_embeddings, dimensions=model_record.dimensions) + self.validate_embeddings(embeddings=query_embeddings, dimensions=_model_record.dimensions) q_permitted_concept_ids = q_concept_ids_with_embeddings( - embedding_table=self._get_embedding_table(session, model_name), + embedding_table=self.get_embedding_table(model_name=model_name, index_type=index_type), concept_filter=concept_filter, limit=None ) @@ -312,7 +234,7 @@ def get_nearest_concepts( distances, concept_ids = storage_manager.search( query_vector=query_embeddings, metric_type=metric_type, - index_type=model_record.index_type, + index_type=_model_record.index_type, k=k, subset_concept_ids=permitted_concept_ids ) @@ -336,18 +258,15 @@ def get_nearest_concepts( )) matches.append(tuple(matches_per_query)) return tuple(matches) - - @require_registered_model - def refresh_model_index(self, 
session: Session, model_name: str, model_record: EmbeddingModelRecord) -> None: - raise NotImplementedError("Explicit index refresh is not required for the FAISS backend as it updates the index on each upsert. This method is a no-op.") @require_registered_model def get_embeddings_by_concept_ids( - self, - session: Session, - model_name: str, - model_record: EmbeddingModelRecord, - concept_ids: Sequence[int] + self, + model_name: str, + index_type: IndexType, + session: Session, + concept_ids: Sequence[int], + _model_record: EmbeddingModelRecord ) -> Mapping[int, Sequence[float]]: concept_id_tuple = tuple(concept_ids) if not concept_id_tuple: @@ -355,46 +274,6 @@ def get_embeddings_by_concept_ids( concept_ids_np = np.array(concept_id_tuple, dtype=np.int64) - storage_manager = self.get_storage_manager( - model_name=model_name, - dimensions=model_record.dimensions, - index_type=model_record.index_type, - ) + storage_manager = self.get_storage_manager(_model_record) return storage_manager.get_embeddings_by_concept_ids(concept_ids=concept_ids_np) - - - def _get_filtered_concept_ids( - self, - session: Session, - concept_filter: EmbeddingConceptFilter, - ) -> tuple[int, ...]: - query = select(Concept.concept_id) - if concept_filter is not None: - query = concept_filter.apply(query) - return tuple(int(row.concept_id) for row in session.execute(query)) - - @staticmethod - def _normalize_matrix(matrix: np.ndarray) -> np.ndarray: - norms = np.linalg.norm(matrix, axis=1, keepdims=True) - norms[norms == 0] = 1.0 - return matrix / norms - - @staticmethod - def _filter_is_empty(concept_filter: EmbeddingConceptFilter) -> bool: - return ( - concept_filter.concept_ids is None - and concept_filter.domains is None - and concept_filter.vocabularies is None - and not concept_filter.require_standard - ) - - @staticmethod - def _create_faiss_index(*, faiss, index_type: str, dimensions: int): - index_type = str(getattr(index_type, "value", index_type)) - if index_type in {"IndexFlatIP", 
"flat", "flatip"}: - return faiss.IndexFlatIP(dimensions) - raise ValueError( - f"Unsupported FAISS index_type={index_type!r} in first-pass backend. " - "Currently supported: IndexFlatIP." - ) \ No newline at end of file diff --git a/src/omop_emb/backends/faiss/faiss_sql.py b/src/omop_emb/backends/faiss/faiss_sql.py index 1d0d6c6..a07e83e 100644 --- a/src/omop_emb/backends/faiss/faiss_sql.py +++ b/src/omop_emb/backends/faiss/faiss_sql.py @@ -9,59 +9,54 @@ from orm_loader.helpers import Base from omop_alchemy.cdm.model.vocabulary import Concept +from omop_emb.model_registry import EmbeddingModelRecord from ..base import ConceptIDEmbeddingBase -from ..config import BackendType -from ..registry import ModelRegistry -from ..embedding_utils import EmbeddingConceptFilter +from omop_emb.utils.embedding_utils import EmbeddingConceptFilter logger = logging.getLogger(__name__) +_FAISS_REGISTRY_TABLE_CACHE: dict[str, type["FAISSConceptIDEmbeddingRegistry"]] = {} + + class FAISSConceptIDEmbeddingRegistry(ConceptIDEmbeddingBase, Base): """Registry table to track which concept_ids are present in the FAISS/H5 storage for a specific model.""" __abstract__ = True def create_faiss_embedding_registry_table( engine: Engine, - model_registry_entry: ModelRegistry, + model_record: EmbeddingModelRecord, ) -> Type[FAISSConceptIDEmbeddingRegistry]: """ Creates a dynamic SQLAlchemy ORM class that tracks which concept_ids are present in the FAISS/H5 storage for a specific model. 
""" - tablename = model_registry_entry.storage_identifier + tablename = model_record.storage_identifier + + cached_table = _FAISS_REGISTRY_TABLE_CACHE.get(tablename) + if cached_table is not None: + Base.metadata.create_all(engine, tables=[cached_table.__table__]) # type: ignore[arg-type] + return cached_table mapping_table = type( - tablename, + f"FAISSConceptIDEmbeddingRegistry_{tablename}", (FAISSConceptIDEmbeddingRegistry, ), { "__tablename__": tablename, "__table_args__": {"extend_existing": True}, + "__module__": __name__, }, ) Base.metadata.create_all(engine, tables=[mapping_table.__table__]) # type: ignore[arg-type] + _FAISS_REGISTRY_TABLE_CACHE[tablename] = mapping_table logger.debug( - f"Initialized {FAISSConceptIDEmbeddingRegistry.__name__} table for model '{model_registry_entry.model_name}'", + f"Initialized {FAISSConceptIDEmbeddingRegistry.__name__} table for model '{model_record.model_name}'", ) return mapping_table -def initialise_faiss_mapping_tables( - engine: Engine, - model_cache: dict[str, Type[FAISSConceptIDEmbeddingRegistry]], -) -> None: - with Session(engine, expire_on_commit=False) as session: - existing_models = session.scalars(select(ModelRegistry).where(ModelRegistry.backend_type == BackendType.FAISS.value)).all() - session.commit() - - for model_entry in existing_models: - if model_entry.model_name not in model_cache: - # Create the class and cache it - dynamic_table = create_faiss_embedding_registry_table(engine=engine, model_registry_entry=model_entry) - model_cache[model_entry.model_name] = dynamic_table - def add_concept_ids_to_faiss_registry( concept_ids: tuple[int, ...], session: Session, diff --git a/src/omop_emb/backends/faiss/index_manager.py b/src/omop_emb/backends/faiss/index_manager.py index 52e8b9c..bd33e50 100644 --- a/src/omop_emb/backends/faiss/index_manager.py +++ b/src/omop_emb/backends/faiss/index_manager.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import Optional, Generator, Tuple -from 
omop_emb.backends.config import IndexType, MetricType +from omop_emb.config import IndexType, MetricType import logging logger = logging.getLogger(__name__) diff --git a/src/omop_emb/backends/faiss/storage_manager.py b/src/omop_emb/backends/faiss/storage_manager.py index 44bf423..fca9331 100644 --- a/src/omop_emb/backends/faiss/storage_manager.py +++ b/src/omop_emb/backends/faiss/storage_manager.py @@ -11,7 +11,7 @@ import logging logger = logging.getLogger(__name__) -from omop_emb.backends.config import BackendType, IndexType, MetricType +from omop_emb.config import BackendType, IndexType, MetricType from .index_manager import FlatIndexManager, BaseIndexManager class EmbeddingStorageManager: diff --git a/src/omop_emb/backends/pgvector/pgvector_backend.py b/src/omop_emb/backends/pgvector/pgvector_backend.py index 92050ad..548cb11 100644 --- a/src/omop_emb/backends/pgvector/pgvector_backend.py +++ b/src/omop_emb/backends/pgvector/pgvector_backend.py @@ -14,15 +14,13 @@ create_pg_embedding_table, add_embeddings_to_registered_table, ) -from ..registry import ModelRegistry -from ..config import BackendType, MetricType +from omop_emb.config import BackendType, MetricType, IndexType from ..base import EmbeddingBackend, require_registered_model -from ..embedding_utils import ( - EmbeddingBackendCapabilities, +from omop_emb.utils.embedding_utils import ( EmbeddingConceptFilter, - EmbeddingModelRecord, NearestConceptMatch, ) +from omop_emb.model_registry import EmbeddingModelRecord class PGVectorEmbeddingBackend(EmbeddingBackend[PGVectorConceptIDEmbeddingTable]): """ @@ -38,49 +36,54 @@ class PGVectorEmbeddingBackend(EmbeddingBackend[PGVectorConceptIDEmbeddingTable] def backend_type(self) -> BackendType: return BackendType.PGVECTOR - @property - def capabilities(self) -> EmbeddingBackendCapabilities: - return EmbeddingBackendCapabilities( - stores_embeddings=True, - supports_incremental_upsert=True, - supports_nearest_neighbor_search=True, - 
supports_server_side_similarity=True, - supports_filter_by_concept_ids=True, - supports_filter_by_domain=True, - supports_filter_by_vocabulary=True, - supports_filter_by_standard_flag=True, - requires_explicit_index_refresh=False, - ) + def _create_storage_table(self, engine: Engine, model_record: EmbeddingModelRecord) -> Type[PGVectorConceptIDEmbeddingTable]: + return create_pg_embedding_table(engine=engine, model_record=model_record) - def _create_storage_table(self, engine: Engine, entry: ModelRegistry) -> Type[PGVectorConceptIDEmbeddingTable]: - return create_pg_embedding_table(engine=engine, model_registry_entry=entry) - - def initialise_store(self, engine: Engine) -> None: + def pre_initialise_store(self, engine: Engine) -> None: with Session(engine, expire_on_commit=False) as session: session.execute(text("CREATE EXTENSION IF NOT EXISTS vector CASCADE;")) session.commit() - return super().initialise_store(engine) @require_registered_model def upsert_embeddings( self, - session: Session, model_name: str, - model_record: EmbeddingModelRecord, + index_type: IndexType, + *, + session: Session, concept_ids: Sequence[int], embeddings: ndarray, + _model_record: EmbeddingModelRecord, ) -> None: + """ + Insert or update embeddings for OMOP concept IDs. + + Parameters + ---------- + model_name : str + Registered name of the embedding model. + index_type : IndexType + Storage index type used for this model's embeddings. + session : Session + SQLAlchemy session bound to the OMOP CDM database. + concept_ids : Sequence[int] + Concept IDs aligned with the rows of ``embeddings``. + embeddings : ndarray + Embedding matrix of shape ``(n_concepts, D)``. + _model_record : EmbeddingModelRecord + Internal registered-model record injected by ``@require_registered_model``. 
+ """ concept_id_tuple = tuple(concept_ids) self.validate_embeddings_and_concept_ids( concept_ids=concept_id_tuple, embeddings=embeddings, - dimensions=model_record.dimensions, + dimensions=_model_record.dimensions, ) - table = self._get_embedding_table( - session=session, + table = self.get_embedding_table( model_name=model_name, + index_type=index_type, ) try: @@ -99,18 +102,37 @@ def upsert_embeddings( @require_registered_model def get_embeddings_by_concept_ids( self, - session: Session, model_name: str, - model_record: EmbeddingModelRecord, + index_type: IndexType, + *, + session: Session, concept_ids: Sequence[int], + _model_record: EmbeddingModelRecord, ) -> Mapping[int, Sequence[float]]: + """ + Return embeddings for the requested concept IDs. + + Parameters + ---------- + model_name : str + Registered name of the embedding model. + index_type : IndexType + Storage index type used for this model's embeddings. + session : Session + SQLAlchemy session bound to the OMOP CDM database. + concept_ids : Sequence[int] + Concept IDs to fetch from backend storage. + _model_record : EmbeddingModelRecord + Internal registered-model record injected by ``@require_registered_model``. 
+ """ + concept_id_tuple = tuple(concept_ids) if not concept_id_tuple: return {} - embedding_table = self._get_embedding_table( - session=session, + embedding_table = self.get_embedding_table( model_name=model_name, + index_type=index_type, ) query = q_embedding_vectors_by_concept_ids( embedding_table=embedding_table, @@ -132,21 +154,43 @@ def get_embeddings_by_concept_ids( @require_registered_model def get_nearest_concepts( self, - session: Session, model_name: str, - model_record: EmbeddingModelRecord, - query_embeddings: ndarray, + index_type: IndexType, *, + session: Session, + query_embeddings: ndarray, metric_type: MetricType, concept_filter: Optional[EmbeddingConceptFilter] = None, k: int = 10, - ) -> Tuple[Tuple[NearestConceptMatch, ...], ...]: - - embedding_table = self._get_embedding_table( - session=session, + _model_record: EmbeddingModelRecord, + ) -> Tuple[Tuple[NearestConceptMatch, ...], ...]: + """ + Return nearest stored concepts for query embeddings. + + Parameters + ---------- + model_name : str + Registered name of the embedding model. + index_type : IndexType + Storage index type used for this model's embeddings. + session : Session + SQLAlchemy session bound to the OMOP CDM database. + query_embeddings : ndarray + Query embedding matrix of shape ``(Q, D)``. + metric_type : MetricType + Similarity or distance metric for nearest-neighbor search. + concept_filter : Optional[EmbeddingConceptFilter], optional + Optional filter restricting which OMOP concepts are considered. + k : int, optional + Number of nearest matches to return per query. + _model_record : EmbeddingModelRecord + Internal registered-model record injected by ``@require_registered_model``. 
+ """ + embedding_table = self.get_embedding_table( model_name=model_name, + index_type=index_type, ) - self.validate_embeddings(embeddings=query_embeddings, dimensions=model_record.dimensions) + self.validate_embeddings(embeddings=query_embeddings, dimensions=_model_record.dimensions) query_list = query_embeddings.tolist() query = q_embedding_nearest_concepts( @@ -173,17 +217,3 @@ def get_nearest_concepts( return tuple(tuple(matches) for matches in results) - def refresh_model_index(self, session: Session, model_name: str) -> None: - # pgvector indexes update as rows are written, so no explicit refresh is - # required for the current PostgreSQL backend. - return None - - def _record_from_registry_row(self, row: ModelRegistry) -> EmbeddingModelRecord: - return EmbeddingModelRecord( - model_name=row.model_name, - dimensions=row.dimensions, - backend_type=row.backend_type, - storage_identifier=row.storage_identifier, - index_type=row.index_type, - metadata=self._coerce_registry_metadata(row.details), - ) diff --git a/src/omop_emb/backends/pgvector/pgvector_sql.py b/src/omop_emb/backends/pgvector/pgvector_sql.py index 855c7b1..b89fb1e 100644 --- a/src/omop_emb/backends/pgvector/pgvector_sql.py +++ b/src/omop_emb/backends/pgvector/pgvector_sql.py @@ -11,14 +11,17 @@ import logging from numpy import ndarray -from ..config import BackendType, IndexType, MetricType -from ..registry import ModelRegistry +from omop_emb.config import BackendType, IndexType, MetricType +from omop_emb.model_registry import EmbeddingModelRecord from ..base import ConceptIDEmbeddingBase -from ..embedding_utils import EmbeddingConceptFilter, get_similarity_from_distance +from omop_emb.utils.embedding_utils import EmbeddingConceptFilter, get_similarity_from_distance logger = logging.getLogger(__name__) +_PGVECTOR_TABLE_CACHE: dict[str, type["PGVectorConceptIDEmbeddingTable"]] = {} + + class PGVectorConceptIDEmbeddingTable(ConceptIDEmbeddingBase, Base): """Abstract base for Postgres linter 
support.""" __abstract__ = True @@ -26,41 +29,35 @@ class PGVectorConceptIDEmbeddingTable(ConceptIDEmbeddingBase, Base): def create_pg_embedding_table( engine: Engine, - model_registry_entry: ModelRegistry, + model_record: EmbeddingModelRecord, ) -> Type[PGVectorConceptIDEmbeddingTable]: # Note the specific Type hint here - tablename = model_registry_entry.storage_identifier - dimensions = model_registry_entry.dimensions + tablename = model_record.storage_identifier + dimensions = model_record.dimensions + + cached_table = _PGVECTOR_TABLE_CACHE.get(tablename) + if cached_table is not None: + Base.metadata.create_all(engine, tables=[cached_table.__table__]) # type: ignore[arg-type] + create_indices_for_embedding_table(engine, model_record) + return cached_table table = type( - tablename, - (PGVectorConceptIDEmbeddingTable,), # It already inherits from Base and ConceptIDBase + f"PGVectorConceptIDEmbeddingTable_{tablename}", + (PGVectorConceptIDEmbeddingTable,), { "__tablename__": tablename, - # We override the attribute with the specific dimension-aware column "embedding": mapped_column(Vector(dimensions), nullable=False, index=False), "__table_args__": {"extend_existing": True}, + "__module__": __name__, } ) Base.metadata.create_all(engine, tables=[table.__table__]) # type: ignore[arg-type] + _PGVECTOR_TABLE_CACHE[tablename] = table - with engine.connect() as conn: - inspector = inspect(conn) - existing_indexes = [idx['name'] for idx in inspector.get_indexes(model_registry_entry.storage_identifier)] - - create_index_sql = _create_index_sql( - model_registry_entry.storage_identifier, - IndexType(model_registry_entry.index_type), - ) - if ( - _index_from_storage_identifier(model_registry_entry.storage_identifier) not in existing_indexes and - create_index_sql is not None - ): - conn.execute(text(create_index_sql)) - conn.commit() + create_indices_for_embedding_table(engine, model_record) logger.debug( - f"Initialized {PGVectorConceptIDEmbeddingTable.__name__} table 
for model '{model_registry_entry.model_name}' with dimensions {model_registry_entry.dimensions} using index_method={model_registry_entry.index_type}", + f"Initialized {PGVectorConceptIDEmbeddingTable.__name__} table for model '{model_record.model_name}' with dimensions {model_record.dimensions} using index_method={model_record.index_type}", ) return table @@ -99,7 +96,21 @@ def add_embeddings_to_registered_table( session.commit() -def _create_index_sql(table_name: str, index_type: IndexType) -> Optional[str]: +def create_indices_for_embedding_table(engine: Engine, model_record: EmbeddingModelRecord) -> None: + with engine.connect() as conn: + inspector = inspect(conn) + existing_indexes = [idx['name'] for idx in inspector.get_indexes(model_record.storage_identifier)] + + if _index_from_storage_identifier(model_record.storage_identifier) not in existing_indexes: + query = q_create_index_sql( + model_record.storage_identifier, + model_record.index_type, + ) + if query is not None: + conn.execute(text(query)) + conn.commit() + +def q_create_index_sql(table_name: str, index_type: IndexType) -> Optional[str]: if index_type == IndexType.FLAT: pass # No additional index needed for flat, as the vector column itself can be used for sequential scan else: diff --git a/src/omop_emb/cli.py b/src/omop_emb/cli.py index 28957dd..2f4c4e9 100644 --- a/src/omop_emb/cli.py +++ b/src/omop_emb/cli.py @@ -3,24 +3,97 @@ import sqlalchemy as sa from sqlalchemy.orm import Session -from orm_loader.helpers import get_logger, configure_logging, create_db -from omop_alchemy.cdm.model.vocabulary import Concept +from orm_loader.helpers import get_logger, create_db from typing import Annotated, Optional +from pathlib import Path +import csv +import json import os +import logging from dotenv import load_dotenv from tqdm import tqdm import typer -from omop_emb.backends import ( - EmbeddingConceptFilter, -) +from omop_emb.utils.embedding_utils import EmbeddingConceptFilter from omop_emb.interface 
import EmbeddingInterface -from omop_emb.backends.config import BackendType, IndexType +from omop_emb.config import IndexType, BackendType app = typer.Typer() logger = get_logger(__name__) +SNAPSHOT_MANIFEST_NAME = "manifest.json" + + +def configure_logging_level(verbosity: int) -> None: + """Configure global logging based on CLI verbosity flags.""" + level_map = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} + log_level = level_map.get(min(verbosity, 2), logging.DEBUG) + + logging.basicConfig( + level=log_level, + format="%(asctime)s | %(name)s | %(levelname)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + force=True, + ) + + +def _load_legacy_rows(engine: sa.Engine, legacy_table: str) -> list[dict[str, object]]: + inspector = sa.inspect(engine) + if not inspector.has_table(legacy_table): + raise RuntimeError(f"Legacy table '{legacy_table}' does not exist in source database.") + + query = sa.text(f'SELECT * FROM "{legacy_table}"') + with engine.connect() as conn: + rows = conn.execute(query).mappings().all() + return [dict(row) for row in rows] + + +def _legacy_row_fields(row: dict[str, object]) -> tuple[str, int, IndexType, str, dict[str, object]]: + model_name_raw = row.get("model_name") + dimensions_raw = row.get("dimensions") + if model_name_raw is None or dimensions_raw is None: + raise ValueError("Legacy row missing required fields 'model_name' and/or 'dimensions'.") + + index_raw = row.get("index_type") or row.get("index_method") or IndexType.FLAT.value + index_type = IndexType(str(index_raw)) + + storage_identifier_raw = row.get("storage_identifier") or row.get("table_name") + if storage_identifier_raw is None: + raise ValueError("Legacy row missing required storage field ('storage_identifier' or 'table_name').") + storage_identifier = str(storage_identifier_raw) + + details_raw = row.get("details") or row.get("metadata") or {} + if isinstance(details_raw, dict): + metadata = dict(details_raw) + else: + metadata = {} + + model_name = 
str(model_name_raw) + dimensions = int(str(dimensions_raw)) + + return model_name, dimensions, index_type, storage_identifier, metadata + + +def _resolve_engine() -> sa.Engine: + engine_string = os.getenv('OMOP_DATABASE_URL') + if engine_string is None: + raise RuntimeError("OMOP_DATABASE_URL environment variable not set. Please set it in your .env file to point to your database.") + + engine = sa.create_engine(engine_string, future=True, echo=False) + if engine.dialect.name != "postgresql": + raise RuntimeError("Only PostgreSQL databases are supported for embedding storage with the current backends. Please check your `OMOP_DATABASE_URL` environment variable and ensure it points to a PostgreSQL database.") + return engine + + +def _build_pgvector_interface(storage_base_dir: Optional[str]) -> EmbeddingInterface: + interface = EmbeddingInterface.from_backend_name( + backend_name=BackendType.PGVECTOR, + storage_base_dir=storage_base_dir, + ) + if interface.backend.backend_type != BackendType.PGVECTOR: + raise RuntimeError("Resolved embedding backend is not pgvector. Set --storage-base-dir and/or OMOP_EMB_BACKEND appropriately.") + return interface + @app.command() def add_embeddings( @@ -45,10 +118,9 @@ def add_embeddings( "--backend", help="Embedding backend to use. Can be replaced by the `OMOP_EMB_BACKEND` environment variable." )] = None, - - faiss_base_dir: Annotated[Optional[str], typer.Option( - "--faiss-base-dir", - help="Optional base directory for FAISS backend storage." + storage_base_dir: Annotated[Optional[str], typer.Option( + "--storage-base-dir", + help="Optional base directory for embedding backend storage. Falls back to `OMOP_EMB_BASE_STORAGE_DIR` if not provided, or defaults to ./.omop_emb in the current working directory. Paths with `~` are expanded." )] = None, standard_only: Annotated[bool, typer.Option( "--standard-only", @@ -58,23 +130,26 @@ def add_embeddings( "--vocabulary", help="Optional vocabulary filter. 
Repeat the option to embed concepts only from specific OMOP vocabularies." )] = None, + domains: Annotated[Optional[list[str]], typer.Option( + "--domain", + help="Optional domain filter. Repeat the option to embed concepts only from specific OMOP domains." + )] = None, num_embeddings: Annotated[Optional[int], typer.Option( "--num-embeddings", "-n", help="If set, limits the number of concepts for which embeddings are generated. Useful for testing and development to speed up the embedding generation step.")] = None, + verbosity: Annotated[int, typer.Option( + "--verbose", "-v", count=True, + help="Increase verbosity (up to two levels)" + )] = 0, ): - configure_logging() + configure_logging_level(verbosity) load_dotenv() - engine_string = os.getenv('OMOP_DATABASE_URL') - if engine_string is None: - raise RuntimeError("OMOP_DATABASE_URL environment variable not set. Please set it in your .env file to point to your database.") - - engine = sa.create_engine(engine_string, future=True, echo=False) - assert engine.dialect.name == "postgresql", "Only PostgreSQL databases are supported for embedding storage with the current backends. Please check your `OMOP_DATABASE_URL` environment variable and ensure it points to a PostgreSQL database." + engine = _resolve_engine() interface = EmbeddingInterface.from_backend_name( backend_name=backend_name, - faiss_base_dir=faiss_base_dir, + storage_base_dir=storage_base_dir, embedding_client=LLMClient( model=model, api_base=api_base, @@ -87,32 +162,31 @@ def add_embeddings( concept_filter = EmbeddingConceptFilter( require_standard=standard_only, + domains=tuple(domains) if domains else None, vocabularies=tuple(vocabularies) if vocabularies else None, ) # Ensure OMOP metadata tables exist, then initialize the embedding store. 
create_db(engine) - interface.initialise_store(engine) + interface.setup_and_register_model( + engine=engine, + model_name=model, + dimensions=embedding_dim, + index_type=index_type, + ) with Session(engine) as reader, Session(engine) as writer: - interface.ensure_model_registered( - engine=engine, - session=reader, - model_name=model, - index_type=index_type, - dimensions=embedding_dim - ) - total_concepts = num_embeddings or interface.get_concepts_without_embedding_count( session=reader, model_name=model, concept_filter=concept_filter, + index_type=index_type ) - concepts_without_embedding = interface.get_concepts_without_embedding_query( - session=reader, + concepts_without_embedding = interface.q_get_concepts_without_embedding( model_name=model, concept_filter=concept_filter, limit=num_embeddings, + index_type=index_type ) logger.info(f"Total concepts to process: {total_concepts}") @@ -128,6 +202,7 @@ def add_embeddings( concept_ids=tuple(batch_concepts.keys()), concept_texts=tuple(batch_concepts.values()), batch_size=batch_size, + index_type=index_type ) pbar.update(len(batch_concepts)) @@ -135,5 +210,288 @@ def add_embeddings( logger.info("Completed embedding generation and storage.") +@app.command() +def export_pgvector( + output_dir: Annotated[str, typer.Option( + "--output-dir", "-o", + help="Directory where pgvector snapshot files (CSV + manifest) are written.", + )], + storage_base_dir: Annotated[Optional[str], typer.Option( + "--storage-base-dir", + help="Optional base directory for omop-emb metadata registry (metadata.db). Falls back to `OMOP_EMB_BASE_STORAGE_DIR` if not provided, or defaults to ./.omop_emb in the current working directory. Paths with `~` are expanded.", + )] = None, + model: Annotated[Optional[list[str]], typer.Option( + "--model", "-m", + help="Optional model-name filter. 
Repeat to export specific models only.", + )] = None, + index_type: Annotated[Optional[IndexType], typer.Option( + "--index-type", + help="Optional index-type filter.", + )] = None, + verbosity: Annotated[int, typer.Option( + "--verbose", "-v", count=True, + help="Increase verbosity (up to two levels)" + )] = 0, +): + """Export pgvector embedding tables to file for checkpoint/restore workflows.""" + configure_logging_level(verbosity) + load_dotenv() + + engine = _resolve_engine() + interface = _build_pgvector_interface(storage_base_dir=storage_base_dir) + interface.initialise_store(engine) + + records = interface.backend.embedding_model_registry.get_registered_models_from_db( + backend_type=BackendType.PGVECTOR + ) + if not records: + raise RuntimeError("No pgvector models found in local registry metadata. Nothing to export.") + + model_filter = set(model) if model else None + selected_records = [ + record for record in records + if (model_filter is None or record.model_name in model_filter) + and (index_type is None or record.index_type == index_type) + ] + if not selected_records: + raise RuntimeError("No pgvector models matched the provided filters.") + + output_path = Path(output_dir).resolve() + output_path.mkdir(parents=True, exist_ok=True) + + manifest = { + "format_version": 1, + "backend": BackendType.PGVECTOR.value, + "tables": [], + } + + with engine.connect() as conn: + for record in selected_records: + table_name = record.storage_identifier + quoted_table_name = f'"{table_name}"' + row_count = conn.execute(sa.text(f"SELECT COUNT(*) FROM {quoted_table_name}")) + n_rows = int(row_count.scalar_one()) + + csv_filename = f"{table_name}.csv" + csv_path = output_path / csv_filename + + query = sa.text( + f"SELECT concept_id, embedding::text AS embedding FROM {quoted_table_name} ORDER BY concept_id" + ) + result = conn.execution_options(stream_results=True).execute(query) + + with csv_path.open("w", encoding="utf-8", newline="") as handle: + writer = 
csv.writer(handle) + writer.writerow(["concept_id", "embedding"]) + for row in result: + writer.writerow([int(row.concept_id), str(row.embedding)]) + + logger.info(f"Exported {n_rows} rows from {table_name} to {csv_path}") + manifest["tables"].append( + { + "model_name": record.model_name, + "index_type": record.index_type.value, + "dimensions": record.dimensions, + "storage_identifier": record.storage_identifier, + "metadata": dict(record.metadata), + "rows": n_rows, + "file": csv_filename, + } + ) + + manifest_path = output_path / SNAPSHOT_MANIFEST_NAME + manifest_path.write_text(json.dumps(manifest, indent=2), encoding="utf-8") + logger.info(f"Wrote pgvector snapshot manifest to {manifest_path}") + + +@app.command() +def import_pgvector( + input_dir: Annotated[str, typer.Option( + "--input-dir", "-i", + help="Directory containing pgvector snapshot files and manifest.json.", + )], + storage_base_dir: Annotated[Optional[str], typer.Option( + "--storage-base-dir", + help="Optional base directory for omop-emb metadata registry (metadata.db). Falls back to `OMOP_EMB_BASE_STORAGE_DIR` if not provided, or defaults to ./.omop_emb in the current working directory. 
Paths with `~` are expanded.", + )] = None, + replace: Annotated[bool, typer.Option( + "--replace", + help="If set, truncate each destination embedding table before import.", + )] = False, + batch_size: Annotated[int, typer.Option( + "--batch-size", "-b", + help="Number of rows per INSERT batch.", + )] = 5000, + verbosity: Annotated[int, typer.Option( + "--verbose", "-v", count=True, + help="Increase verbosity (up to two levels)" + )] = 0, +): + """Import pgvector embedding tables from a previously exported snapshot.""" + configure_logging_level(verbosity) + load_dotenv() + + input_path = Path(input_dir).resolve() + manifest_path = input_path / SNAPSHOT_MANIFEST_NAME + if not manifest_path.is_file(): + raise FileNotFoundError(f"Snapshot manifest not found at {manifest_path}") + + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + if manifest.get("backend") != BackendType.PGVECTOR.value: + raise ValueError("Snapshot manifest backend is not pgvector.") + + table_specs = manifest.get("tables") + if not isinstance(table_specs, list) or not table_specs: + raise ValueError("Snapshot manifest does not contain any tables.") + + engine = _resolve_engine() + interface = _build_pgvector_interface(storage_base_dir=storage_base_dir) + + for table_spec in table_specs: + model_name = str(table_spec["model_name"]) + index_type_enum = IndexType(str(table_spec["index_type"])) + dimensions = int(table_spec["dimensions"]) + expected_table_name = str(table_spec["storage_identifier"]) + metadata = table_spec.get("metadata") or {} + + interface.register_model( + engine=engine, + model_name=model_name, + dimensions=dimensions, + index_type=index_type_enum, + metadata=metadata, + ) + actual_table_name = interface.get_model_table_name( + model_name=model_name, + index_type=index_type_enum, + ) + if actual_table_name != expected_table_name: + raise RuntimeError( + f"Storage identifier mismatch for model '{model_name}': expected '{expected_table_name}', got 
'{actual_table_name}'." + ) + + interface.initialise_store(engine) + + with Session(engine) as session: + for table_spec in table_specs: + table_name = str(table_spec["storage_identifier"]) + csv_file = input_path / str(table_spec["file"]) + if not csv_file.is_file(): + raise FileNotFoundError(f"Missing snapshot table file: {csv_file}") + + quoted_table_name = f'"{table_name}"' + if replace: + session.execute(sa.text(f"TRUNCATE TABLE {quoted_table_name}")) + session.commit() + + insert_sql = sa.text( + f""" + INSERT INTO {quoted_table_name} (concept_id, embedding) + VALUES (:concept_id, CAST(:embedding AS vector)) + ON CONFLICT (concept_id) DO UPDATE + SET embedding = EXCLUDED.embedding + """ + ) + + total_rows = 0 + batch: list[dict[str, object]] = [] + with csv_file.open("r", encoding="utf-8", newline="") as handle: + reader = csv.DictReader(handle) + for row in reader: + batch.append( + { + "concept_id": int(row["concept_id"]), + "embedding": row["embedding"], + } + ) + if len(batch) >= batch_size: + session.execute(insert_sql, batch) + session.commit() + total_rows += len(batch) + batch = [] + + if batch: + session.execute(insert_sql, batch) + session.commit() + total_rows += len(batch) + + logger.info(f"Imported {total_rows} rows into {table_name} from {csv_file}") + + +@app.command() +def migrate_legacy_pgvector_registry( + storage_base_dir: Annotated[Optional[str], typer.Option( + "--storage-base-dir", + help="Optional base directory for omop-emb metadata registry (metadata.db). Falls back to `OMOP_EMB_BASE_STORAGE_DIR` if not provided, or defaults to ./.omop_emb in the current working directory. Paths with `~` are expanded.", + )] = None, + source_database_url: Annotated[Optional[str], typer.Option( + "--source-database-url", + help="Source database URL containing the legacy model_registry table. 
Defaults to OMOP_DATABASE_URL.", + )] = None, + legacy_table: Annotated[str, typer.Option( + "--legacy-table", + help="Name of the legacy table to migrate from.", + )] = "model_registry", + dry_run: Annotated[bool, typer.Option( + "--dry-run", + help="If set, report rows that would be migrated without writing to local metadata.", + )] = False, + drop_legacy_registry: Annotated[bool, typer.Option( + "--drop-legacy-registry", + help="If set, drop the legacy table after successful migration.", + )] = False, + verbosity: Annotated[int, typer.Option( + "--verbose", "-v", count=True, + help="Increase verbosity (up to two levels)" + )] = 0, +): + """Migrate legacy pgvector registry rows into local metadata.db registry.""" + configure_logging_level(verbosity) + load_dotenv() + + source_url = source_database_url or os.getenv("OMOP_DATABASE_URL") + if source_url is None: + raise RuntimeError("OMOP_DATABASE_URL is not set. Provide --source-database-url.") + + source_engine = sa.create_engine(source_url, future=True, echo=False) + interface = _build_pgvector_interface(storage_base_dir=storage_base_dir) + + legacy_rows = _load_legacy_rows(source_engine, legacy_table=legacy_table) + if not legacy_rows: + logger.info("No legacy registry rows found. 
Nothing to migrate.") + return + + logger.info(f"Found {len(legacy_rows)} legacy registry rows in '{legacy_table}'.") + + migrated = 0 + for row in legacy_rows: + model_name, dimensions, index_type, storage_identifier, metadata = _legacy_row_fields(row) + + if dry_run: + logger.info( + f"[DRY RUN] Would migrate model={model_name}, index={index_type.value}, " + f"dimensions={dimensions}, storage_identifier={storage_identifier}" + ) + migrated += 1 + continue + + interface.backend.embedding_model_registry.register_model( + model_name=model_name, + dimensions=dimensions, + backend_type=BackendType.PGVECTOR, + index_type=index_type, + metadata=metadata, + storage_identifier=storage_identifier, + ) + migrated += 1 + + logger.info(f"Migrated {migrated} legacy registry rows into local metadata registry.") + + if drop_legacy_registry and not dry_run: + with source_engine.begin() as conn: + conn.execute(sa.text(f'DROP TABLE IF EXISTS "{legacy_table}"')) + logger.info(f"Dropped legacy registry table '{legacy_table}'.") + + if __name__ == "__main__": app() diff --git a/src/omop_emb/backends/config.py b/src/omop_emb/config.py similarity index 61% rename from src/omop_emb/backends/config.py rename to src/omop_emb/config.py index 5e8418f..3096dd2 100644 --- a/src/omop_emb/backends/config.py +++ b/src/omop_emb/config.py @@ -2,7 +2,7 @@ from typing import Dict, Tuple ENV_OMOP_EMB_BACKEND = "OMOP_EMB_BACKEND" -ENV_OMOP_EMB_FAISS_INDEX_DIR = "OMOP_EMB_FAISS_INDEX_DIR" +ENV_BASE_STORAGE_DIR = "OMOP_EMB_BASE_STORAGE_DIR" # For backends that use file-based storage, e.g. 
FAISS with on-disk indices or registry metadata storage class BackendType(StrEnum): @@ -39,6 +39,45 @@ class MetricType(StrEnum): HAMMING = "hamming" JACCARD = "jaccard" + +def parse_backend_type(value: str | BackendType) -> BackendType: + """Normalize a backend value to ``BackendType`` with a clear error message.""" + if isinstance(value, BackendType): + return value + try: + return BackendType(value) + except ValueError as exc: + raise ValueError( + f"Invalid backend type {value!r}. Expected one of " + f"{[member.value for member in BackendType]}." + ) from exc + + +def parse_index_type(value: str | IndexType) -> IndexType: + """Normalize an index value to ``IndexType`` with a clear error message.""" + if isinstance(value, IndexType): + return value + try: + return IndexType(value) + except ValueError as exc: + raise ValueError( + f"Invalid index type {value!r}. Expected one of " + f"{[member.value for member in IndexType]}." + ) from exc + + +def parse_metric_type(value: str | MetricType) -> MetricType: + """Normalize a metric value to ``MetricType`` with a clear error message.""" + if isinstance(value, MetricType): + return value + try: + return MetricType(value) + except ValueError as exc: + raise ValueError( + f"Invalid metric type {value!r}. Expected one of " + f"{[member.value for member in MetricType]}." 
+ ) from exc + # TODO: Support non-flat indices in the future SUPPORTED_INDICES_AND_METRICS_PER_BACKEND: Dict[BackendType, Dict[IndexType, Tuple[MetricType, ...]]] = { #https://github.com/pgvector/pgvector?tab=readme-ov-file#querying diff --git a/src/omop_emb/interface.py b/src/omop_emb/interface.py index 267c645..577570f 100644 --- a/src/omop_emb/interface.py +++ b/src/omop_emb/interface.py @@ -12,11 +12,11 @@ from .backends import ( EmbeddingBackend, - EmbeddingConceptFilter, - EmbeddingModelRecord, get_embedding_backend, ) -from .backends.config import IndexType, MetricType +from omop_emb.utils.embedding_utils import EmbeddingConceptFilter +from omop_emb.model_registry import EmbeddingModelRecord +from .config import BackendType, IndexType, MetricType @dataclass @@ -47,25 +47,43 @@ def embedding_dim(self) -> Optional[int]: def from_backend_name( cls, embedding_client: Optional[LLMClient] = None, - backend_name: Optional[str] = None, - *, - faiss_base_dir: Optional[str] = None, + backend_name: Optional[str | BackendType] = None, + storage_base_dir: Optional[str] = None, + registry_db_name: Optional[str] = None, ) -> EmbeddingInterface: return cls( backend=get_embedding_backend( backend_name=backend_name, - faiss_base_dir=faiss_base_dir, + storage_base_dir=storage_base_dir, + registry_db_name=registry_db_name, ), embedding_client=embedding_client, ) + def setup_and_register_model( + self, + engine: Engine, + model_name: str, + dimensions: int, + index_type: IndexType, + metadata: Mapping[str, object] = {}, + ) -> None: + self.initialise_store(engine) + self.register_model( + engine=engine, + model_name=model_name, + dimensions=dimensions, + index_type=index_type, + metadata=metadata, + ) + def initialise_store(self, engine: Engine) -> None: self.backend.initialise_store(engine) def get_model_table_name( self, - session: Session, model_name: str, + index_type: IndexType, ) -> Optional[str]: """ Legacy helper preserved for compatibility. 
@@ -73,34 +91,52 @@ def get_model_table_name( Historically this returned a PostgreSQL table name. Under the backend abstraction it returns the backend-specific storage identifier. """ - record = self.backend.get_registered_model(session=session, model_name=model_name) + record = self.backend.get_registered_model(model_name=model_name, index_type=index_type) return record.storage_identifier if record is not None else None def is_model_registered( self, - session: Session, model_name: str, + index_type: IndexType, ) -> bool: - return self.backend.is_model_registered(session=session, model_name=model_name) + return self.backend.is_model_registered(model_name=model_name, index_type=index_type) + def register_model( + self, + engine: Engine, + model_name: str, + dimensions: int, + index_type: IndexType, + metadata: Mapping[str, object] = {}, + ) -> EmbeddingModelRecord: + return self.backend.register_model( + engine=engine, + model_name=model_name, + dimensions=dimensions, + index_type=index_type, + metadata=metadata, + ) def has_any_embeddings( self, session: Session, embedding_model_name: str, + index_type: IndexType, ) -> bool: return self.backend.has_any_embeddings( session=session, model_name=embedding_model_name, + index_type=index_type, ) def get_nearest_concepts( self, session: Session, model_name: str, + index_type: IndexType, query_embedding: np.ndarray, *, - metric_type: MetricType | str, + metric_type: MetricType, concept_filter: Optional[EmbeddingConceptFilter] = None, k: int = 10, ) -> Tuple[Mapping[int, float], ...]: @@ -114,10 +150,12 @@ def get_nearest_concepts( SQLAlchemy session for any required relational access. model_name : str Name of the embedding model used to create the embeddings. + index_type : IndexType + The type of vector index used to store the embeddings. query_embedding : ndarray The embedding vector to search with. 
Expected shape is (q, dimension) where q is the number of query vectors and dimension is the size of the embedding space for the model. - metric_type : MetricType, str + metric_type : MetricType The similarity or distance metric to use for nearest neighbor search. This must be compatible with the index type used by the database. concept_filter : Optional[EmbeddingConceptFilter], optional A filter to specify which concepts to consider as potential nearest neighbors. @@ -133,13 +171,15 @@ def get_nearest_concepts( Tuple[Mapping[int, float], ...] A tuple of dictionaries containing nearest concept matches for each query vector. The outer tuple corresponds to the query vectors in order, and each inner dictionary contains the nearest matches for that query vector, sorted by similarity. Returned shape is (q, k) where q is the number of query vectors and k is the number of nearest neighbors returned per query. """ - if isinstance(metric_type, str): - metric_type = MetricType(metric_type) - assert isinstance(metric_type, MetricType), f"metric_type must be a MetricType enum or string. Got {type(metric_type)}" + if not isinstance(metric_type, MetricType): + raise TypeError( + f"metric_type must be MetricType, got {type(metric_type).__name__}." + ) nearest_concepts = self.backend.get_nearest_concepts( session=session, model_name=model_name, - query_embedding=query_embedding, + index_type=index_type, + query_embeddings=query_embedding, concept_filter=concept_filter, metric_type=metric_type, k=k @@ -150,6 +190,7 @@ def get_nearest_concepts_by_texts( self, session: Session, embedding_model_name: str, + index_type: IndexType, query_texts: str | Tuple[str, ...] | List[str], *, metric_type: MetricType, @@ -167,6 +208,8 @@ def get_nearest_concepts_by_texts( SQLAlchemy session for any required relational access. embedding_model_name : str Name of the embedding model used to create the embeddings. + index_type : IndexType + The type of vector index used to store the embeddings. 
query_texts : str | Tuple[str, ...] | List[str] The text(s) to embed and search with. If a single string is provided, it will be embedded and searched as one query. If a tuple or list of strings is provided, each string will be embedded and searched separately, and the results will be returned in the same order as the input texts. metric_type : MetricType @@ -193,6 +236,7 @@ def get_nearest_concepts_by_texts( return self.get_nearest_concepts( session=session, model_name=embedding_model_name, + index_type=index_type, query_embedding=query_embeddings, metric_type=metric_type, concept_filter=concept_filter, @@ -203,11 +247,13 @@ def get_embeddings_by_concept_ids( self, session: Session, embedding_model_name: str, + index_type: IndexType, concept_ids: Tuple[int, ...], ) -> Mapping[int, Sequence[float]]: return self.backend.get_embeddings_by_concept_ids( session=session, model_name=embedding_model_name, + index_type=index_type, concept_ids=concept_ids, ) @@ -221,6 +267,7 @@ def initialise_tables(self, engine: Engine): def add_to_db( self, session: Session, + index_type: IndexType, concept_ids: Tuple[int, ...], embeddings: ndarray, model: str, @@ -234,38 +281,10 @@ def add_to_db( return self.backend.upsert_embeddings( session=session, model_name=model, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) - - def ensure_model_registered( - self, - *, - engine: Engine, - session: Session, - model_name: str, - dimensions: int, - index_type: IndexType, - metadata: Mapping[str, object] = {}, - ) -> EmbeddingModelRecord: - """ - Ensure the embedding model exists in the selected backend. 
- """ - - existing = self.backend.get_registered_model( - session=session, - model_name=model_name, - ) - if existing is not None: - return existing - else: - return self.backend.register_model( - engine=engine, - model_name=model_name, - dimensions=dimensions, - index_type=index_type, - metadata=metadata, - ) def embed_texts( self, @@ -284,12 +303,14 @@ def upsert_concept_embeddings( *, session: Session, model_name: str, + index_type: IndexType, concept_ids: Sequence[int], embeddings: ndarray, ) -> None: self.backend.upsert_embeddings( session=session, model_name=model_name, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) @@ -299,6 +320,7 @@ def embed_and_upsert_concepts( *, session: Session, model_name: str, + index_type: IndexType, concept_ids: Sequence[int], concept_texts: Sequence[str], embedding_client: Optional[LLMClient] = None, @@ -317,6 +339,7 @@ def embed_and_upsert_concepts( self.upsert_concept_embeddings( session=session, model_name=model_name, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) @@ -327,27 +350,29 @@ def get_concepts_without_embedding( *, session: Session, model_name: str, + index_type: IndexType, concept_filter: Optional[EmbeddingConceptFilter] = None, limit: Optional[int] = None, ) -> Mapping[int, str]: return self.backend.get_concepts_without_embedding( session=session, model_name=model_name, + index_type=index_type, concept_filter=concept_filter, limit=limit, ) - def get_concepts_without_embedding_query( + def q_get_concepts_without_embedding( self, *, - session: Session, model_name: str, + index_type: IndexType, concept_filter: Optional[EmbeddingConceptFilter] = None, limit: Optional[int] = None, ) -> Select: - return self.backend.get_concepts_without_embedding_query( - session=session, + return self.backend.q_get_concepts_without_embedding( model_name=model_name, + index_type=index_type, concept_filter=concept_filter, limit=limit, ) @@ -357,10 +382,12 @@ def 
get_concepts_without_embedding_count( *, session: Session, model_name: str, + index_type: IndexType, concept_filter: Optional[EmbeddingConceptFilter] = None, ) -> int: return self.backend.get_concepts_without_embedding_count( session=session, model_name=model_name, + index_type=index_type, concept_filter=concept_filter, ) diff --git a/src/omop_emb/model_registry/__init__.py b/src/omop_emb/model_registry/__init__.py new file mode 100644 index 0000000..47c7feb --- /dev/null +++ b/src/omop_emb/model_registry/__init__.py @@ -0,0 +1,2 @@ +from .model_registry_types import EmbeddingModelRecord +from .model_registry_manager import ModelRegistryManager \ No newline at end of file diff --git a/src/omop_emb/backends/registry.py b/src/omop_emb/model_registry/model_registry_cdm.py similarity index 74% rename from src/omop_emb/backends/registry.py rename to src/omop_emb/model_registry/model_registry_cdm.py index 868b5cb..5dda099 100644 --- a/src/omop_emb/backends/registry.py +++ b/src/omop_emb/model_registry/model_registry_cdm.py @@ -1,18 +1,20 @@ from __future__ import annotations from sqlalchemy import DateTime, Engine, Integer, JSON, String, func, inspect, text, Enum -from sqlalchemy.orm import mapped_column, validates +from sqlalchemy.orm import DeclarativeBase, mapped_column, validates -from orm_loader.helpers import Base - -from .config import ( +from ..config import ( is_index_type_supported_for_backend, get_supported_index_types_for_backend, IndexType, BackendType ) -class ModelRegistry(Base): +class ModelRegistryBase(DeclarativeBase): + """Dedicated declarative base for local model registry metadata.""" + + +class ModelRegistry(ModelRegistryBase): """ Shared database-backed registry for embedding models across backends. @@ -26,33 +28,34 @@ class ModelRegistry(Base): __tablename__ = "model_registry" - # TODO: Think about having multiple models per index and backend. Would require to - # have the storage_identifier (i.e. 
the table_name) to be unique as well, which is - # currently not enforced explicitly but implicitily as it only depends on model_name that is unique model_name = mapped_column(String, primary_key=True) + backend_type = mapped_column(Enum(BackendType, native_enum=False), nullable=False, primary_key=True) + index_type = mapped_column(Enum(IndexType, native_enum=False), nullable=False, primary_key=True) dimensions = mapped_column(Integer, nullable=False) storage_identifier = mapped_column("table_name", String, unique=True, nullable=False) - index_type = mapped_column(Enum(IndexType, native_enum=False), nullable=False) - backend_type = mapped_column(Enum(BackendType, native_enum=False), nullable=False) details = mapped_column(JSON, nullable=True, default=dict) created_at = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now()) updated_at = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()) + + @validates("backend_type") + def validate_backend_type(self, key, backend_type): + if backend_type not in BackendType: + raise ValueError(f"Unsupported backend type: {backend_type}. Supported backends: {list(BackendType)}") + return backend_type + @validates("index_type") def validate_index_for_backend(self, key, index_type): - if is_index_type_supported_for_backend(self.backend_type, index_type): + if self.backend_type is None: + raise ValueError("Instantiate the registry entry with a valid backend_type before setting index_type.") + + if not is_index_type_supported_for_backend(self.backend_type, index_type): raise ValueError( f"Backend {self.backend_type} does not support {index_type}. " f"Supported: {get_supported_index_types_for_backend(self.backend_type)}" ) return index_type - - @validates("backend_type") - def validate_backend_type(self, key, backend_type): - if backend_type not in BackendType: - raise ValueError(f"Unsupported backend type: {backend_type}. 
Supported backends: {list(BackendType)}") - return backend_type def ensure_model_registry_schema(engine: Engine) -> None: - Base.metadata.create_all(engine, tables=[ModelRegistry.__table__]) # type: ignore[arg-type] + ModelRegistryBase.metadata.create_all(engine, tables=[ModelRegistry.__table__]) # type: ignore[arg-type] diff --git a/src/omop_emb/model_registry/model_registry_manager.py b/src/omop_emb/model_registry/model_registry_manager.py new file mode 100644 index 0000000..947c9e4 --- /dev/null +++ b/src/omop_emb/model_registry/model_registry_manager.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +from sqlalchemy import Engine, create_engine, and_, select +from sqlalchemy.orm import Session +import logging +from typing import Optional, Mapping +import pathlib +import re + +from omop_emb.config import ( + IndexType, + BackendType +) +from .model_registry_cdm import ModelRegistry, ensure_model_registry_schema +from .model_registry_types import EmbeddingModelRecord +from omop_emb.utils.errors import ModelRegistrationConflictError + +logger = logging.getLogger(__name__) + + + +class ModelRegistryManager: + """Manages model registry (metadata) for embedding models locally in a separate SQLite database.""" + DB_FILENAME = "metadata.db" + def __init__( + self, + base_dir: str | pathlib.Path, + db_file: Optional[str] = None, + ): + db_file = db_file or self.DB_FILENAME + self.db_path = pathlib.Path(base_dir) / db_file + self.engine = create_engine(f"sqlite:///{self.db_path}") + + ensure_model_registry_schema(self.engine) + + def get_registered_models_from_db( + self, + backend_type: Optional[BackendType] = None, + model_name: Optional[str] = None, + index_type: Optional[IndexType] = None, + ) -> Optional[tuple[EmbeddingModelRecord, ...]]: + """ + Get the registered embedding models from the database, optionally filtered by backend type, model name, or index type. 
+ """ + stmt = select(ModelRegistry) + if backend_type is not None: + stmt = stmt.where(ModelRegistry.backend_type == backend_type) + if model_name is not None: + stmt = stmt.where(ModelRegistry.model_name == model_name) + if index_type is not None: + stmt = stmt.where(ModelRegistry.index_type == index_type) + + with Session(self.engine, expire_on_commit=False) as session: + existing_models = session.scalars(stmt).all() + + if not existing_models: + return None + + return tuple(self._registry_entry_to_model_record(match) for match in existing_models) + + + def register_model( + self, + model_name: str, + dimensions: int, + *, + backend_type: BackendType, + index_type: IndexType, + metadata: Optional[Mapping[str, object]] = None, + storage_identifier: Optional[str] = None, + ) -> EmbeddingModelRecord: + """ + Shared template method for model registration. + """ + with Session(self.engine, expire_on_commit=False) as session: + existing_row = session.scalar( + select(ModelRegistry).where( + and_( + ModelRegistry.model_name == model_name, + ModelRegistry.backend_type == backend_type, + ModelRegistry.index_type == index_type, + ) + ) + ) + if existing_row is not None: + if existing_row.backend_type != backend_type: + raise ModelRegistrationConflictError( + f"Model '{model_name}' is already registered with backend " + f"'{existing_row.backend_type}', not '{backend_type}'. " + "Reuse the existing model name or choose a new one.", + conflict_field="backend_type" + ) + + if existing_row.dimensions != dimensions: + raise ModelRegistrationConflictError( + f"Model '{model_name}' is already registered with dimensions " + f"{existing_row.dimensions}, not {dimensions}.", + conflict_field="dimensions" + ) + if existing_row.index_type != index_type: + raise ModelRegistrationConflictError( + f"Model '{model_name}' is already registered with " + f"index_method='{existing_row.index_type}', not " + f"'{index_type}'. 
Reuse the existing model " + "configuration or register a new model name.", + conflict_field="index_type" + ) + if existing_row.details != (metadata or {}): + raise ModelRegistrationConflictError( + f"Model '{model_name}' is already registered with different " + f"metadata. Reuse the existing model name or choose a new one.", + conflict_field="metadata" + ) + if storage_identifier is not None and existing_row.storage_identifier != storage_identifier: + raise ModelRegistrationConflictError( + f"Model '{model_name}' is already registered with storage identifier " + f"'{existing_row.storage_identifier}', not '{storage_identifier}'.", + conflict_field="storage_identifier" + ) + return self._registry_entry_to_model_record(existing_row) + + metadata = metadata or {} + + safe_name = self.safe_model_name(model_name) + storage_name = storage_identifier or self.storage_name( + safe_model_name=safe_name, + index_type=index_type, + backend_type=backend_type + ) + + new_entry = ModelRegistry( + model_name=model_name, + backend_type=backend_type, + index_type=index_type, + dimensions=dimensions, + storage_identifier=storage_name, + details=metadata + ) + + with Session(self.engine, expire_on_commit=False) as session: + session.add(new_entry) + session.commit() + return self._registry_entry_to_model_record(new_entry) + + + @staticmethod + def safe_model_name(model_name: str) -> str: + name = model_name.lower() + sanitized = re.sub(r"[^\w]+", "_", name) + sanitized = re.sub(r"_+", "_", sanitized).strip("_") + return sanitized + + @staticmethod + def storage_name( + safe_model_name: str, + index_type: IndexType, + backend_type: BackendType + ) -> str: + return f"{backend_type.value.lower()}_{safe_model_name}_{index_type.value}" + + @staticmethod + def _coerce_registry_metadata(value: object) -> Mapping[str, object]: + if isinstance(value, Mapping): + return dict(value) + return {} + + @staticmethod + def _registry_entry_to_model_record(entry: ModelRegistry) -> EmbeddingModelRecord: 
return EmbeddingModelRecord( + model_name=entry.model_name, + dimensions=entry.dimensions, + backend_type=entry.backend_type, + index_type=entry.index_type, + storage_identifier=entry.storage_identifier, + metadata=ModelRegistryManager._coerce_registry_metadata(entry.details), + ) diff --git a/src/omop_emb/model_registry/model_registry_types.py b/src/omop_emb/model_registry/model_registry_types.py new file mode 100644 index 0000000..4cc4acd --- /dev/null +++ b/src/omop_emb/model_registry/model_registry_types.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Mapping + +from omop_emb.config import ( + IndexType, + BackendType +) + +@dataclass(frozen=True) +class EmbeddingModelRecord: + """ + Canonical description of a registered embedding model. + + ``storage_identifier`` is intentionally backend-specific. For example: + - PostgreSQL backend: dynamic embedding table name + - FAISS backend: on-disk index path or logical collection name + """ + + model_name: str + dimensions: int + backend_type: BackendType + index_type: IndexType + storage_identifier: str + metadata: Mapping[str, object] = field(default_factory=dict) diff --git a/src/omop_emb/utils/__init__.py b/src/omop_emb/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/omop_emb/backends/embedding_utils.py b/src/omop_emb/utils/embedding_utils.py similarity index 64% rename from src/omop_emb/backends/embedding_utils.py rename to src/omop_emb/utils/embedding_utils.py index 5f66576..202d26a 100644 --- a/src/omop_emb/backends/embedding_utils.py +++ b/src/omop_emb/utils/embedding_utils.py @@ -4,25 +4,7 @@ from sqlalchemy import Select from omop_alchemy.cdm.model.vocabulary import Concept -from .config import BackendType, SUPPORTED_INDICES_AND_METRICS_PER_BACKEND, IndexType, MetricType - - -@dataclass(frozen=True) -class EmbeddingModelRecord: - """ - Canonical description of a registered embedding model. 
- - ``storage_identifier`` is intentionally backend-specific. For example: - - PostgreSQL backend: dynamic embedding table name - - FAISS backend: on-disk index path or logical collection name - """ - - model_name: str - dimensions: int - backend_type: BackendType - index_type: IndexType - storage_identifier: Optional[str] = None - metadata: Mapping[str, object] = field(default_factory=dict) +from ..config import BackendType, IndexType, MetricType @dataclass(frozen=True) @@ -72,27 +54,6 @@ class NearestConceptMatch: is_active: bool -@dataclass(frozen=True) -class EmbeddingBackendCapabilities: - """ - Capability flags for a backend implementation. - - These are not used by the current code yet, but they make backend - differences explicit. For example, a FAISS backend might support nearest - neighbor search but require explicit refreshes after bulk writes. - """ - - stores_embeddings: bool = True - supports_incremental_upsert: bool = True - supports_nearest_neighbor_search: bool = True - supports_server_side_similarity: bool = True - supports_filter_by_concept_ids: bool = True - supports_filter_by_domain: bool = True - supports_filter_by_vocabulary: bool = True - supports_filter_by_standard_flag: bool = True - requires_explicit_index_refresh: bool = False - - def get_similarity_from_distance(distance_col, metric: MetricType): """ Helper to map various distance metrics to a 0.0 - 1.0 similarity score. 
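The `get_similarity_from_distance` helper retained by the rename above maps a backend distance to a 0.0–1.0 similarity score. A minimal standalone sketch of that idea, using plain floats rather than the SQLAlchemy column expressions the real helper operates on — the exact formulas and the `Metric` enum here are assumptions for illustration, not the package's implementation:

```python
from enum import Enum


class Metric(Enum):
    # stand-in for omop_emb's MetricType; values are illustrative
    L2 = "l2"
    COSINE = "cosine"


def similarity_from_distance(distance: float, metric: Metric) -> float:
    """Map a raw distance to a similarity in [0.0, 1.0] (illustrative formulas)."""
    if metric is Metric.COSINE:
        # cosine distance lies in [0, 2]; clamp 1 - distance into [0, 1]
        return max(0.0, min(1.0, 1.0 - distance))
    if metric is Metric.L2:
        # squash the unbounded L2 distance monotonically into (0, 1]
        return 1.0 / (1.0 + distance)
    raise ValueError(f"Unsupported metric: {metric}")


print(similarity_from_distance(0.0, Metric.L2))      # identical vectors
print(similarity_from_distance(2.0, Metric.COSINE))  # opposite directions
```

Whatever the concrete formula, the useful property is that a distance of zero maps to a similarity of exactly 1.0 for every metric, which is what the shared similarity tests in `tests/shared_backend_tests.py` rely on.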
diff --git a/src/omop_emb/backends/errors.py b/src/omop_emb/utils/errors.py similarity index 100% rename from src/omop_emb/backends/errors.py rename to src/omop_emb/utils/errors.py diff --git a/tests/conftest.py b/tests/conftest.py index 44c9480..54c27f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,6 @@ import pytest import numpy as np import sqlalchemy as sa -from sqlalchemy.engine import make_url from sqlalchemy.orm import Session, sessionmaker @@ -23,9 +22,7 @@ from omop_emb.backends.faiss import FaissEmbeddingBackend from omop_emb.backends.pgvector import PGVectorEmbeddingBackend -from omop_emb.backends.registry import ensure_model_registry_schema, ModelRegistry from omop_emb.interface import EmbeddingInterface -from omop_emb.backends.config import IndexType TEST_DB_NAME = os.getenv("TEST_DATABASE_NAME", "test_omop_emb") @@ -90,8 +87,6 @@ def _create_test_database() -> sa.URL: conn.execute(sa.text(f"DROP ROLE IF EXISTS {TEST_USERNAME}")) conn.execute(sa.text(f"CREATE USER {TEST_USERNAME} WITH PASSWORD '{TEST_PASSWORD}' SUPERUSER")) conn.execute(sa.text(f'CREATE DATABASE "{TEST_DB_NAME}" OWNER {TEST_USERNAME}')) - conn.execute(sa.text(f"DROP TABLE IF EXISTS {ModelRegistry.__tablename__} CASCADE;")) - finally: admin_engine.dispose() @@ -135,7 +130,6 @@ def pg_engine(): # Create all tables Base.metadata.create_all(engine) - ensure_model_registry_schema(engine) yield engine @@ -158,11 +152,8 @@ def session(pg_engine) -> Generator[Session, None, None]: # Clear tables before test with pg_engine.connect() as conn: with conn.begin(): - res = conn.execute(sa.text(f"SELECT table_name FROM {ModelRegistry.__tablename__}")) - for (table_name,) in res: - conn.execute(sa.text(f"DROP TABLE IF EXISTS {table_name} CASCADE")) - conn.execute(sa.text(f"TRUNCATE TABLE concept, {ModelRegistry.__tablename__} CASCADE")) - + conn.execute(sa.text(f"TRUNCATE TABLE concept CASCADE")) + Session = sessionmaker(bind=pg_engine, future=True) db_session = Session() @@ 
-194,24 +185,24 @@ def create_embeddings(concept_names: list[str] | str, batch_size: Optional[int] @pytest.fixture -def temp_faiss_dir(): +def temp_storage_dir(): """Temporary directory for FAISS indices.""" with tempfile.TemporaryDirectory() as tmpdir: yield Path(tmpdir) @pytest.fixture -def faiss_backend(session, temp_faiss_dir) -> FaissEmbeddingBackend: +def faiss_backend(session, temp_storage_dir) -> FaissEmbeddingBackend: """FAISS backend with model registry initialized.""" - backend = FaissEmbeddingBackend(base_dir=temp_faiss_dir) + backend = FaissEmbeddingBackend(storage_base_dir=temp_storage_dir) backend.initialise_store(session.bind) return backend @pytest.fixture -def pgvector_backend(session) -> PGVectorEmbeddingBackend: +def pgvector_backend(session, temp_storage_dir) -> PGVectorEmbeddingBackend: """PGVector backend with vector extension and model registry initialized.""" - backend = PGVectorEmbeddingBackend() + backend = PGVectorEmbeddingBackend(storage_base_dir=temp_storage_dir) backend.initialise_store(session.bind) return backend diff --git a/tests/shared_backend_tests.py b/tests/shared_backend_tests.py index 2b6c8f8..0a9dca4 100644 --- a/tests/shared_backend_tests.py +++ b/tests/shared_backend_tests.py @@ -3,48 +3,48 @@ import pytest import numpy as np -from omop_emb.backends.config import IndexType, MetricType -from omop_emb.backends.embedding_utils import EmbeddingConceptFilter -from omop_emb.backends.errors import ModelRegistrationConflictError +from omop_emb.config import IndexType, MetricType +from omop_emb.utils.embedding_utils import EmbeddingConceptFilter +from omop_emb.utils.errors import ModelRegistrationConflictError from .conftest import CONCEPTS, MODEL_NAME, EMBEDDING_DIM, TEST_CONCEPT_EMB class SharedBackendTests: """Base class containing backend-parity tests via a generic ``backend`` fixture.""" - def test_backend_registration(self, session, backend): + def test_backend_registration(self, session, backend, index_type: IndexType = 
IndexType.FLAT): """Test registering model with the backend.""" model = backend.register_model( engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) assert model.model_name == MODEL_NAME assert model.dimensions == EMBEDDING_DIM - def test_get_registered_model(self, session, backend): + def test_get_registered_model(self, session, backend, index_type: IndexType = IndexType.FLAT): """Test retrieving registered model.""" backend.register_model( engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) - retrieved = backend.get_registered_model(session, MODEL_NAME) + retrieved = backend.get_registered_model(model_name=MODEL_NAME, index_type=index_type) assert retrieved is not None assert retrieved.model_name == MODEL_NAME - def test_duplicate_registration_identical_success(self, session, backend): + def test_duplicate_registration_identical_success(self, session, backend, index_type: IndexType = IndexType.FLAT): """Test that idempotent model registration returns the existing record.""" params = { "engine": session.bind, "model_name": MODEL_NAME, "dimensions": EMBEDDING_DIM, - "index_type": IndexType.FLAT, + "index_type": index_type, "metadata": {"version": "1.0"}, } @@ -54,13 +54,13 @@ def test_duplicate_registration_identical_success(self, session, backend): assert model1.storage_identifier == model2.storage_identifier assert model2.metadata["version"] == "1.0" - def test_registration_metadata_conflict_raises_error(self, session, backend): + def test_registration_metadata_conflict_raises_error(self, session, backend, index_type: IndexType = IndexType.FLAT): """Test conflicting metadata on re-registration raises conflict error.""" params = { "engine": session.bind, "model_name": MODEL_NAME, "dimensions": EMBEDDING_DIM, - "index_type": IndexType.FLAT, + "index_type": index_type, "metadata": {"version": "1.0"}, } 
backend.register_model(**params) @@ -73,13 +73,13 @@ def test_registration_metadata_conflict_raises_error(self, session, backend): assert excinfo.value.conflict_field == "metadata" - def test_upsert_embeddings(self, session, backend, mock_llm_client): + def test_upsert_embeddings(self, session, backend, mock_llm_client, index_type: IndexType = IndexType.FLAT): """Test upserting embeddings.""" backend.register_model( engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) test_concept = CONCEPTS["Hypertension"] @@ -89,19 +89,20 @@ def test_upsert_embeddings(self, session, backend, mock_llm_client): backend.upsert_embeddings( session=session, model_name=MODEL_NAME, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) - assert backend.has_any_embeddings(session, MODEL_NAME) + assert backend.has_any_embeddings(session=session, model_name=MODEL_NAME, index_type=index_type) - def test_get_embeddings_by_ids(self, session, backend, mock_llm_client): + def test_get_embeddings_by_ids(self, session, backend, mock_llm_client, index_type: IndexType = IndexType.FLAT): """Test retrieving embeddings by concept IDs.""" backend.register_model( engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) test_concepts = [CONCEPTS["Hypertension"], CONCEPTS["Diabetes"]] @@ -111,6 +112,7 @@ def test_get_embeddings_by_ids(self, session, backend, mock_llm_client): backend.upsert_embeddings( session=session, model_name=MODEL_NAME, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) @@ -119,18 +121,19 @@ def test_get_embeddings_by_ids(self, session, backend, mock_llm_client): session=session, model_name=MODEL_NAME, concept_ids=concept_ids, + index_type=index_type, ) assert len(retrieved) == 2 assert set(retrieved.keys()) == set(concept_ids) - def test_nearest_neighbor_search(self, session, backend, 
mock_llm_client): + def test_nearest_neighbor_search(self, session, backend, mock_llm_client, index_type: IndexType = IndexType.FLAT): """Test exact top-1 nearest-neighbor retrieval with deterministic embeddings.""" backend.register_model( engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) test_concepts = list(CONCEPTS.values()) @@ -140,6 +143,7 @@ def test_nearest_neighbor_search(self, session, backend, mock_llm_client): backend.upsert_embeddings( session=session, model_name=MODEL_NAME, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) @@ -148,6 +152,7 @@ def test_nearest_neighbor_search(self, session, backend, mock_llm_client): results = backend.get_nearest_concepts( session=session, model_name=MODEL_NAME, + index_type=index_type, query_embeddings=query_embeddings, metric_type=MetricType.L2, k=1, @@ -157,13 +162,13 @@ def test_nearest_neighbor_search(self, session, backend, mock_llm_client): assert len(results[0]) == 1 assert results[0][0].concept_id == CONCEPTS["Hypertension"].concept_id - def test_nearest_neighbor_with_domain_filter(self, session, backend, mock_llm_client): + def test_nearest_neighbor_with_domain_filter(self, session, backend, mock_llm_client, index_type: IndexType = IndexType.FLAT): """Test nearest neighbor search with domain filter.""" backend.register_model( engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) test_concepts = list(CONCEPTS.values()) @@ -173,6 +178,7 @@ def test_nearest_neighbor_with_domain_filter(self, session, backend, mock_llm_cl backend.upsert_embeddings( session=session, model_name=MODEL_NAME, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) @@ -180,6 +186,7 @@ def test_nearest_neighbor_with_domain_filter(self, session, backend, mock_llm_cl results = backend.get_nearest_concepts( session=session, model_name=MODEL_NAME, + 
index_type=index_type, query_embeddings=TEST_CONCEPT_EMB, concept_filter=EmbeddingConceptFilter(domains=("Condition",)), metric_type=MetricType.L2, @@ -190,13 +197,13 @@ def test_nearest_neighbor_with_domain_filter(self, session, backend, mock_llm_cl assert len(results) == 1 assert set(m.concept_id for m in results[0]) == expected_ids - def test_nearest_neighbor_with_vocabulary_filter(self, session, backend, mock_llm_client): + def test_nearest_neighbor_with_vocabulary_filter(self, session, backend, mock_llm_client, index_type: IndexType = IndexType.FLAT): """Test nearest neighbor search with vocabulary filter.""" backend.register_model( engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) test_concepts = list(CONCEPTS.values()) @@ -206,6 +213,7 @@ def test_nearest_neighbor_with_vocabulary_filter(self, session, backend, mock_ll backend.upsert_embeddings( session=session, model_name=MODEL_NAME, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) @@ -213,6 +221,7 @@ def test_nearest_neighbor_with_vocabulary_filter(self, session, backend, mock_ll results = backend.get_nearest_concepts( session=session, model_name=MODEL_NAME, + index_type=index_type, query_embeddings=TEST_CONCEPT_EMB, concept_filter=EmbeddingConceptFilter(vocabularies=("RxNorm",)), metric_type=MetricType.L2, @@ -223,13 +232,13 @@ def test_nearest_neighbor_with_vocabulary_filter(self, session, backend, mock_ll assert len(results[0]) == 1 assert results[0][0].concept_id == CONCEPTS["Aspirin"].concept_id - def test_get_concepts_without_embedding(self, session, backend, mock_llm_client): + def test_get_concepts_without_embedding(self, session, backend, mock_llm_client, index_type: IndexType = IndexType.FLAT): """Test retrieving concepts without embeddings.""" backend.register_model( engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) 
test_concept = CONCEPTS["Hypertension"] @@ -239,6 +248,7 @@ def test_get_concepts_without_embedding(self, session, backend, mock_llm_client) backend.upsert_embeddings( session=session, model_name=MODEL_NAME, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) @@ -246,13 +256,14 @@ def test_get_concepts_without_embedding(self, session, backend, mock_llm_client) unembedded = backend.get_concepts_without_embedding( session=session, model_name=MODEL_NAME, + index_type=index_type, ) assert CONCEPTS["Aspirin"].concept_id in unembedded assert CONCEPTS["Diabetes"].concept_id in unembedded assert CONCEPTS["Hypertension"].concept_id not in unembedded - def test_l2_similarity_exact_values(self, session, backend, mock_llm_client): + def test_l2_similarity_exact_values(self, session, backend, mock_llm_client, index_type: IndexType = IndexType.FLAT): """Test L2 distance calculations yield expected similarity scores. With deterministic embeddings: @@ -264,8 +275,8 @@ def test_l2_similarity_exact_values(self, session, backend, mock_llm_client): backend.register_model( engine=session.bind, model_name=MODEL_NAME, + index_type=index_type, dimensions=EMBEDDING_DIM, - index_type=IndexType.FLAT, ) test_concepts = list(CONCEPTS.values()) @@ -275,6 +286,7 @@ def test_l2_similarity_exact_values(self, session, backend, mock_llm_client): backend.upsert_embeddings( session=session, model_name=MODEL_NAME, + index_type=index_type, concept_ids=concept_ids, embeddings=embeddings, ) @@ -282,6 +294,7 @@ def test_l2_similarity_exact_values(self, session, backend, mock_llm_client): results = backend.get_nearest_concepts( session=session, model_name=MODEL_NAME, + index_type=index_type, query_embeddings=TEST_CONCEPT_EMB, metric_type=MetricType.L2, k=10, @@ -301,7 +314,7 @@ def test_l2_similarity_exact_values(self, session, backend, mock_llm_client): rtol=1e-5, ) - def test_cosine_similarity_exact_values(self, session, backend, mock_llm_client): + def 
test_cosine_similarity_exact_values(self, session, backend, mock_llm_client, index_type: IndexType = IndexType.FLAT):
         """Test cosine distance calculations yield expected similarity scores.
 
         With deterministic embeddings (1D vectors):
@@ -316,8 +329,8 @@ def test_cosine_similarity_exact_values(self, session, backend, mock_llm_client)
         backend.register_model(
             engine=session.bind,
             model_name=MODEL_NAME,
+            index_type=index_type,
             dimensions=EMBEDDING_DIM,
-            index_type=IndexType.FLAT,
         )
 
         test_concepts = list(CONCEPTS.values())
@@ -327,6 +340,7 @@ def test_cosine_similarity_exact_values(self, session, backend, mock_llm_client)
         backend.upsert_embeddings(
             session=session,
             model_name=MODEL_NAME,
+            index_type=index_type,
             concept_ids=concept_ids,
             embeddings=embeddings,
         )
@@ -334,6 +348,7 @@ def test_cosine_similarity_exact_values(self, session, backend, mock_llm_client)
         results = backend.get_nearest_concepts(
             session=session,
             model_name=MODEL_NAME,
+            index_type=index_type,
             query_embeddings=TEST_CONCEPT_EMB,
             metric_type=MetricType.COSINE,
             k=10,
diff --git a/tests/test_cli_pgvector_snapshot.py b/tests/test_cli_pgvector_snapshot.py
new file mode 100644
index 0000000..bb3f2e8
--- /dev/null
+++ b/tests/test_cli_pgvector_snapshot.py
@@ -0,0 +1,184 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+import pytest
+import sqlalchemy as sa
+from sqlalchemy.engine import Engine
+from sqlalchemy.orm import Session
+
+from omop_emb.cli import (
+    export_pgvector,
+    import_pgvector,
+    migrate_legacy_pgvector_registry,
+)
+from omop_emb.config import BackendType, IndexType
+from omop_emb.interface import EmbeddingInterface
+
+from .conftest import CONCEPTS, EMBEDDING_DIM, MODEL_NAME
+
+
+@pytest.mark.pgvector
+@pytest.mark.unit
+def test_pgvector_export_import_roundtrip(
+    session: Session,
+    mock_llm_client,
+    temp_storage_dir: Path,
+) -> None:
+    engine = session.get_bind()
+    assert isinstance(engine, Engine)
+
+    source_storage = temp_storage_dir / "source_registry"
+    source_storage.mkdir(parents=True, exist_ok=True)
+
+    interface = EmbeddingInterface.from_backend_name(
+        backend_name=BackendType.PGVECTOR,
+        storage_base_dir=str(source_storage),
+        embedding_client=mock_llm_client,
+    )
+    interface.initialise_store(engine)
+
+    interface.register_model(
+        engine=engine,
+        model_name=MODEL_NAME,
+        dimensions=EMBEDDING_DIM,
+        index_type=IndexType.FLAT,
+        metadata={"origin": "roundtrip-test"},
+    )
+
+    concept_names = ["Hypertension", "Diabetes"]
+    concept_ids = tuple(CONCEPTS[name].concept_id for name in concept_names)
+    embeddings = mock_llm_client.embeddings(concept_names)
+
+    interface.add_to_db(
+        session=session,
+        index_type=IndexType.FLAT,
+        concept_ids=concept_ids,
+        embeddings=embeddings,
+        model=MODEL_NAME,
+    )
+
+    engine_url = engine.url.render_as_string(hide_password=False)
+
+    snapshot_dir = temp_storage_dir / "snapshot"
+    with pytest.MonkeyPatch.context() as mp:
+        mp.setenv("OMOP_DATABASE_URL", engine_url)
+        export_pgvector(
+            output_dir=str(snapshot_dir),
+            storage_base_dir=str(source_storage),
+            model=[MODEL_NAME],
+            index_type=IndexType.FLAT,
+        )
+
+    manifest_path = snapshot_dir / "manifest.json"
+    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
+    assert manifest["backend"] == BackendType.PGVECTOR.value
+    assert len(manifest["tables"]) == 1
+
+    storage_identifier = manifest["tables"][0]["storage_identifier"]
+    table_name = f'"{storage_identifier}"'
+
+    with engine.begin() as conn:
+        conn.execute(sa.text(f"TRUNCATE TABLE {table_name}"))
+
+    target_storage = temp_storage_dir / "target_registry"
+    target_storage.mkdir(parents=True, exist_ok=True)
+
+    with pytest.MonkeyPatch.context() as mp:
+        mp.setenv("OMOP_DATABASE_URL", engine_url)
+        import_pgvector(
+            input_dir=str(snapshot_dir),
+            storage_base_dir=str(target_storage),
+            replace=True,
+            batch_size=2,
+        )
+
+    with engine.connect() as conn:
+        count = int(conn.execute(sa.text(f"SELECT COUNT(*) FROM {table_name}")).scalar_one())
+    assert count == len(concept_ids)
+
+
+@pytest.mark.pgvector
+@pytest.mark.unit
+def test_migrate_legacy_pgvector_registry(
+    session: Session,
+    temp_storage_dir: Path,
+) -> None:
+    engine = session.get_bind()
+    assert isinstance(engine, Engine)
+
+    legacy_table = "model_registry_legacy"
+
+    with engine.begin() as conn:
+        conn.execute(
+            sa.text(
+                f"""
+                CREATE TABLE {legacy_table} (
+                    model_name TEXT NOT NULL,
+                    dimensions INTEGER NOT NULL,
+                    index_type TEXT NOT NULL,
+                    table_name TEXT NOT NULL,
+                    details JSONB
+                )
+                """
+            )
+        )
+        conn.execute(
+            sa.text(
+                f"""
+                INSERT INTO {legacy_table} (model_name, dimensions, index_type, table_name, details)
+                VALUES (:model_name, :dimensions, :index_type, :table_name, CAST(:details AS JSONB))
+                """
+            ),
+            {
+                "model_name": "legacy-model",
+                "dimensions": 1,
+                "index_type": "flat",
+                "table_name": "legacy_table_flat",
+                "details": '{"migrated": true}',
+            },
+        )
+
+    source_url = engine.url.render_as_string(hide_password=False)
+    storage_dir = temp_storage_dir / "migrated_registry"
+    storage_dir.mkdir(parents=True, exist_ok=True)
+
+    migrate_legacy_pgvector_registry(
+        storage_base_dir=str(storage_dir),
+        source_database_url=source_url,
+        legacy_table=legacy_table,
+        dry_run=True,
+        drop_legacy_registry=False,
+    )
+
+    interface = EmbeddingInterface.from_backend_name(
+        backend_name=BackendType.PGVECTOR,
+        storage_base_dir=str(storage_dir),
+    )
+    migrated_before = interface.backend.embedding_model_registry.get_registered_models_from_db(
+        backend_type=BackendType.PGVECTOR,
+        model_name="legacy-model",
+        index_type=IndexType.FLAT,
+    )
+    assert migrated_before is None
+
+    migrate_legacy_pgvector_registry(
+        storage_base_dir=str(storage_dir),
+        source_database_url=source_url,
+        legacy_table=legacy_table,
+        dry_run=False,
+        drop_legacy_registry=True,
+    )
+
+    migrated = interface.backend.embedding_model_registry.get_registered_models_from_db(
+        backend_type=BackendType.PGVECTOR,
+        model_name="legacy-model",
+        index_type=IndexType.FLAT,
+    )
+    assert migrated is not None
+    assert len(migrated) == 1
+    assert migrated[0].storage_identifier == "legacy_table_flat"
+
+    inspector = sa.inspect(engine)
+    assert not inspector.has_table(legacy_table)
diff --git a/tests/test_config.py b/tests/test_config.py
index 824a09a..00c29be 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -3,7 +3,7 @@
 import pytest
 
 from omop_emb.backends.factory import normalize_backend_name, get_embedding_backend
-from omop_emb.backends.config import BackendType, IndexType, MetricType
+from omop_emb.config import BackendType, IndexType, MetricType
 
 
 @pytest.mark.unit
@@ -17,18 +17,23 @@ def test_normalize_backend_name(self):
 
     def test_get_faiss_backend(self, tmp_path):
         """Test getting FAISS backend."""
-        backend = get_embedding_backend("faiss", faiss_base_dir=str(tmp_path))
+        backend = get_embedding_backend("faiss", storage_base_dir=str(tmp_path))
         assert backend.backend_type == BackendType.FAISS
+
+    def test_get_backend_rejects_relative_storage_base_dir(self):
+        """Relative storage directories are rejected to avoid cwd-dependent registry paths."""
+        with pytest.raises(ValueError, match="absolute path"):
+            get_embedding_backend("faiss", storage_base_dir=".operator")
 
     def test_faiss_supports_flat_index(self):
         """Test FAISS supports FLAT index."""
-        from omop_emb.backends.config import is_index_type_supported_for_backend
+        from omop_emb.config import is_index_type_supported_for_backend
 
         assert is_index_type_supported_for_backend(BackendType.FAISS, IndexType.FLAT)
 
     def test_faiss_supports_cosine_metric(self):
         """Test FAISS supports COSINE metric."""
-        from omop_emb.backends.config import is_supported_index_metric_combination_for_backend
+        from omop_emb.config import is_supported_index_metric_combination_for_backend
 
         assert is_supported_index_metric_combination_for_backend(
             BackendType.FAISS,
diff --git a/tests/test_dummy.py b/tests/test_dummy.py
deleted file mode 100644
index d7f3c09..0000000
--- a/tests/test_dummy.py
+++ /dev/null
@@ -1,3 +0,0 @@
-def test_placeholder():
-    # This test always passes to keep the CI happy until you write real tests
-    assert True
diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py
index d5af158..b4bfac4 100644
--- a/tests/test_fixtures.py
+++ b/tests/test_fixtures.py
@@ -63,10 +63,6 @@ def test_mock_llm_client_generates_embeddings(self, mock_llm_client):
         # Should have correct shape
         assert emb1.shape == (1, mock_llm_client.embedding_dim)
 
-    def test_faiss_backend_initializes(self, faiss_backend, session):
-        """Verify FAISS backend initializes without error."""
-        assert faiss_backend is not None
-        assert faiss_backend.base_dir.exists()
 
     def test_embedding_interface_initializes(self, embedding_interface, session):
         """Verify EmbeddingInterface initializes with mocks."""
diff --git a/tests/test_interface.py b/tests/test_interface.py
index a49ff0f..978a651 100644
--- a/tests/test_interface.py
+++ b/tests/test_interface.py
@@ -6,7 +6,7 @@
 from sqlalchemy.orm import Session
 
 from omop_emb.interface import EmbeddingInterface
-from omop_emb.backends.config import IndexType, MetricType
+from omop_emb.config import IndexType, MetricType
 from omop_emb.backends.base import NearestConceptMatch
 
 from .conftest import CONCEPTS, MODEL_NAME, EMBEDDING_DIM
@@ -17,9 +17,8 @@ class TestInterface:
 
     def test_register_model(self, session, embedding_interface: EmbeddingInterface):
         """Test registering an embedding model."""
-        model = embedding_interface.ensure_model_registered(
+        model = embedding_interface.register_model(
             engine=session.bind,
-            session=session,
             model_name=MODEL_NAME,
             dimensions=EMBEDDING_DIM,
             index_type=IndexType.FLAT,
@@ -28,39 +27,36 @@ def test_register_model(self, session, embedding_interface: EmbeddingInterface):
         assert model.model_name == MODEL_NAME
         assert model.dimensions == EMBEDDING_DIM
 
-    def test_model_registration_idempotent(self, session, embedding_interface: EmbeddingInterface):
+    def test_model_registration_idempotent(self, session, embedding_interface: EmbeddingInterface, index_type: IndexType = IndexType.FLAT):
         """Test registering same model twice returns existing."""
-        m1 = embedding_interface.ensure_model_registered(
+        m1 = embedding_interface.register_model(
             engine=session.bind,
-            session=session,
             model_name=MODEL_NAME,
             dimensions=EMBEDDING_DIM,
-            index_type=IndexType.FLAT,
+            index_type=index_type,
         )
-        m2 = embedding_interface.ensure_model_registered(
+        m2 = embedding_interface.register_model(
             engine=session.bind,
-            session=session,
             model_name=MODEL_NAME,
             dimensions=EMBEDDING_DIM,
-            index_type=IndexType.FLAT,
+            index_type=index_type,
         )
 
         assert m1.storage_identifier == m2.storage_identifier
 
-    def test_is_model_registered(self, session, embedding_interface: EmbeddingInterface):
+    def test_is_model_registered(self, session, embedding_interface: EmbeddingInterface, index_type: IndexType = IndexType.FLAT):
         """Test checking model registration status."""
-        assert not embedding_interface.is_model_registered(session, MODEL_NAME)
+        assert not embedding_interface.is_model_registered(model_name=MODEL_NAME, index_type=index_type)
 
-        embedding_interface.ensure_model_registered(
+        embedding_interface.register_model(
             engine=session.bind,
-            session=session,
             model_name=MODEL_NAME,
             dimensions=EMBEDDING_DIM,
-            index_type=IndexType.FLAT,
+            index_type=index_type,
         )
 
-        assert embedding_interface.is_model_registered(session, MODEL_NAME)
+        assert embedding_interface.is_model_registered(model_name=MODEL_NAME, index_type=index_type)
 
     def test_embed_texts(self, embedding_interface: EmbeddingInterface):
         """Test embedding generation."""
@@ -76,14 +72,13 @@ def test_embed_multiple_texts(self, embedding_interface: EmbeddingInterface):
 
         assert embeddings.shape == (len(texts), EMBEDDING_DIM)
 
-    def test_embed_and_upsert(self, session, embedding_interface: EmbeddingInterface):
+    def test_embed_and_upsert(self, session, embedding_interface: EmbeddingInterface, index_type: IndexType = IndexType.FLAT):
         """Test embedding and upserting concepts."""
-        embedding_interface.ensure_model_registered(
+        embedding_interface.register_model(
             engine=session.bind,
-            session=session,
             model_name=MODEL_NAME,
             dimensions=EMBEDDING_DIM,
-            index_type=IndexType.FLAT,
+            index_type=index_type,
         )
 
         test_concepts = [CONCEPTS["Hypertension"], CONCEPTS["Diabetes"]]
@@ -95,19 +90,19 @@ def test_embed_and_upsert(self, session, embedding_interface: EmbeddingInterface
             model_name=MODEL_NAME,
             concept_ids=concept_ids,
             concept_texts=concept_texts,
+            index_type=index_type,
         )
 
         assert embeddings.shape == (2, EMBEDDING_DIM)
-        assert embedding_interface.has_any_embeddings(session, MODEL_NAME)
+        assert embedding_interface.has_any_embeddings(session, embedding_model_name=MODEL_NAME, index_type=index_type)
 
-    def test_get_embeddings_by_concept_ids(self, session, embedding_interface: EmbeddingInterface):
+    def test_get_embeddings_by_concept_ids(self, session, embedding_interface: EmbeddingInterface, index_type: IndexType = IndexType.FLAT):
         """Test retrieving stored embeddings."""
-        embedding_interface.ensure_model_registered(
+        embedding_interface.register_model(
             engine=session.bind,
-            session=session,
             model_name=MODEL_NAME,
             dimensions=EMBEDDING_DIM,
-            index_type=IndexType.FLAT,
+            index_type=index_type,
        )
 
         test_concepts = [CONCEPTS["Hypertension"], CONCEPTS["Diabetes"]]
@@ -117,6 +112,7 @@ def test_get_embeddings_by_concept_ids(self, session, embedding_interface: Embed
         embeddings = embedding_interface.embed_and_upsert_concepts(
             session=session,
             model_name=MODEL_NAME,
+            index_type=index_type,
             concept_ids=concept_ids,
             concept_texts=concept_texts,
         )
@@ -125,13 +121,14 @@ def test_get_embeddings_by_concept_ids(self, session, embedding_interface: Embed
             session=session,
             embedding_model_name=MODEL_NAME,
             concept_ids=concept_ids,
+            index_type=index_type,
         )
 
         assert len(retrieved) == 2
         assert 1 in retrieved and 2 in retrieved
 
     @pytest.mark.parametrize("backend_name", ["faiss", "pgvector"])
-    def test_search_return_structure_for_backends(self, session, mock_llm_client, backend_name):
+    def test_search_return_structure_for_backends(self, session, mock_llm_client, backend_name, index_type: IndexType = IndexType.FLAT):
         """Search returns tuple[dict[int, float], ...] for both backend modes."""
         mock_backend = Mock()
         mock_backend.get_nearest_concepts.return_value = (
@@ -162,6 +159,7 @@ def test_search_return_structure_for_backends(self, session, mock_llm_client, ba
         result = interface.get_nearest_concepts(
             session=session,
             model_name=MODEL_NAME,
+            index_type=index_type,
             query_embedding=query_embedding,
             metric_type=MetricType.COSINE,
             k=2,
diff --git a/tests/test_interface_validation.py b/tests/test_interface_validation.py
new file mode 100644
index 0000000..ec0d0c7
--- /dev/null
+++ b/tests/test_interface_validation.py
@@ -0,0 +1,51 @@
+"""Validation tests for strict EmbeddingInterface input contracts."""
+
+from unittest.mock import Mock
+
+import numpy as np
+import pytest
+
+from omop_emb.config import IndexType, MetricType
+from omop_emb.interface import EmbeddingInterface
+
+
+@pytest.mark.unit
+class TestInterfaceValidation:
+    def test_get_nearest_concepts_requires_index_type(self):
+        """Index type is required as part of the strict core interface."""
+        interface = EmbeddingInterface(backend=Mock(), embedding_client=Mock())
+        kwargs = {
+            "session": Mock(),
+            "model_name": "test-model",
+            "query_embedding": np.zeros((1, 1), dtype=np.float32),
+            "metric_type": MetricType.COSINE,
+        }
+
+        with pytest.raises(TypeError):
+            interface.get_nearest_concepts(**kwargs)
+
+    def test_get_nearest_concepts_requires_metric_type(self):
+        """Metric type is required as part of the strict core interface."""
+        interface = EmbeddingInterface(backend=Mock(), embedding_client=Mock())
+        kwargs = {
+            "session": Mock(),
+            "model_name": "test-model",
+            "index_type": IndexType.FLAT,
+            "query_embedding": np.zeros((1, 1), dtype=np.float32),
+        }
+
+        with pytest.raises(TypeError):
+            interface.get_nearest_concepts(**kwargs)
+
+    def test_get_nearest_concepts_rejects_non_enum_metric_type(self):
+        """Core interface rejects non-MetricType values with a clear error."""
+        interface = EmbeddingInterface(backend=Mock(), embedding_client=Mock())
+
+        with pytest.raises(TypeError, match="metric_type must be MetricType"):
+            interface.get_nearest_concepts(
+                session=Mock(),
+                model_name="test-model",
+                index_type=IndexType.FLAT,
+                query_embedding=np.zeros((1, 1), dtype=np.float32),
+                metric_type="cosine",  # type: ignore[arg-type]
+            )