Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/imspy-predictors/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "imspy-predictors"
version = "0.5.0"
version = "0.5.1"
description = "ML-based predictors for CCS, retention time, and fragment intensity in mass spectrometry."
authors = [
{ name = "theGreatHerrLebert", email = "davidteschner@googlemail.com" }
Expand Down
26 changes: 23 additions & 3 deletions packages/imspy-predictors/src/imspy_predictors/pretrained/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,23 @@ def _download(url: str, dest: Path) -> None:
urllib.request.urlretrieve(url, str(dest))


def _find_bundled_model(model_name: str) -> Path | None:
"""Check for a model file co-located with this module (source / editable installs)."""
local = Path(__file__).resolve().parent / model_name
if local.is_file():
return local
return None


def ensure_model(model_name: str) -> Path:
"""Return the path to a cached model, downloading it first if necessary.

Resolution order:

1. Bundled alongside this module (editable / source installs).
2. Local cache (``$IMSPY_CACHE_DIR`` or ``~/.cache/imspy/models/``).
3. Download from GitHub Releases.

Parameters
----------
model_name : str
Expand All @@ -100,7 +114,7 @@ def ensure_model(model_name: str) -> Path:
Returns
-------
Path
Absolute path to the cached ``.pt`` file.
Absolute path to the ``.pt`` file.

Raises
------
Expand All @@ -115,11 +129,17 @@ def ensure_model(model_name: str) -> Path:
f"Known models: {sorted(MODELS)}"
)

# 1. Bundled copy (works for editable / source installs where .pt files
# live next to this module even though they are excluded from wheels).
bundled = _find_bundled_model(model_name)
if bundled is not None:
return bundled

meta = MODELS[model_name]
cache_dir = get_cache_dir()
cached_path = cache_dir / model_name

# Fast path: already cached and hash matches.
# 2. Fast path: already cached and hash matches.
if cached_path.exists():
digest = _sha256(cached_path)
if digest == meta["sha256"]:
Expand All @@ -131,7 +151,7 @@ def ensure_model(model_name: str) -> Path:
meta["sha256"],
)

# Download into a temp file in the same filesystem, then atomic-rename.
# 3. Download into a temp file in the same filesystem, then atomic-rename.
cached_path.parent.mkdir(parents=True, exist_ok=True)
url = f"{_RELEASE_BASE}/{meta['filename']}"
logger.info("Downloading model '%s' from %s ...", model_name, url)
Expand Down