diff --git a/packages/imspy-predictors/pyproject.toml b/packages/imspy-predictors/pyproject.toml index e23eae3d..408bbcc4 100644 --- a/packages/imspy-predictors/pyproject.toml +++ b/packages/imspy-predictors/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "imspy-predictors" -version = "0.5.0" +version = "0.5.1" description = "ML-based predictors for CCS, retention time, and fragment intensity in mass spectrometry." authors = [ { name = "theGreatHerrLebert", email = "davidteschner@googlemail.com" } diff --git a/packages/imspy-predictors/src/imspy_predictors/pretrained/hub.py b/packages/imspy-predictors/src/imspy_predictors/pretrained/hub.py index 25d9ae45..346d7962 100644 --- a/packages/imspy-predictors/src/imspy_predictors/pretrained/hub.py +++ b/packages/imspy-predictors/src/imspy_predictors/pretrained/hub.py @@ -89,9 +89,23 @@ def _download(url: str, dest: Path) -> None: urllib.request.urlretrieve(url, str(dest)) +def _find_bundled_model(model_name: str) -> Path | None: + """Check for a model file co-located with this module (source / editable installs).""" + local = Path(__file__).resolve().parent / model_name + if local.is_file(): + return local + return None + + def ensure_model(model_name: str) -> Path: """Return the path to a cached model, downloading it first if necessary. + Resolution order: + + 1. Bundled alongside this module (editable / source installs). + 2. Local cache (``$IMSPY_CACHE_DIR`` or ``~/.cache/imspy/models/``). + 3. Download from GitHub Releases. + Parameters ---------- model_name : str @@ -100,7 +114,7 @@ def ensure_model(model_name: str) -> Path: Returns ------- Path - Absolute path to the cached ``.pt`` file. + Absolute path to the ``.pt`` file. Raises ------ @@ -115,11 +129,17 @@ def ensure_model(model_name: str) -> Path: f"Known models: {sorted(MODELS)}" ) + # 1. Bundled copy (works for editable / source installs where .pt files + # live next to this module even though they are excluded from wheels). + bundled = _find_bundled_model(model_name) + if bundled is not None: + return bundled + meta = MODELS[model_name] cache_dir = get_cache_dir() cached_path = cache_dir / model_name - # Fast path: already cached and hash matches. + # 2. Fast path: already cached and hash matches. if cached_path.exists(): digest = _sha256(cached_path) if digest == meta["sha256"]: @@ -131,7 +151,7 @@ def ensure_model(model_name: str) -> Path: meta["sha256"], ) - # Download into a temp file in the same filesystem, then atomic-rename. + # 3. Download into a temp file in the same filesystem, then atomic-rename. cached_path.parent.mkdir(parents=True, exist_ok=True) url = f"{_RELEASE_BASE}/{meta['filename']}" logger.info("Downloading model '%s' from %s ...", model_name, url)