Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 140 additions & 30 deletions packages/imspy-predictors/src/imspy_predictors/ccs/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- PyTorchCCSPredictor: PyTorch transformer-based model
"""

import warnings
from typing import List, Tuple, Optional, Union
from abc import ABC, abstractmethod

Expand Down Expand Up @@ -88,6 +89,31 @@ def predict_inverse_ion_mobility(
ps.inverse_ion_mobility_predicted = mob


def _extract_prediction_mean(pred, key: Optional[str] = None):
"""Extract the mean tensor from legacy, tuple, and dict model outputs."""
if isinstance(pred, dict):
if key is None:
if len(pred) != 1:
raise KeyError(
"Cannot infer prediction key from multi-task model output."
)
pred = next(iter(pred.values()))
else:
pred = pred[key]
if isinstance(pred, tuple):
pred = pred[0]
return pred


def _extract_prediction_std(pred, key: Optional[str] = None):
"""Extract an uncertainty tensor when a model output exposes one."""
if isinstance(pred, dict):
pred = pred[key] if key is not None else next(iter(pred.values()))
if isinstance(pred, tuple) and len(pred) > 1:
return pred[1]
return None


def get_sqrt_slopes_and_intercepts(
mz: NDArray,
charge: NDArray,
Expand Down Expand Up @@ -459,60 +485,102 @@ def simulate_ion_mobilities(
return_uncertainty: If True, also return predicted uncertainty (std)

Returns:
Inverse ion mobilities (1/K0), or tuple of (1/K0, std) if return_uncertainty=True
Inverse ion mobilities (1/K0), or tuple of (1/K0, std) if
return_uncertainty=True. Charges outside the model's
``[1, _CHARGE_MAX]`` domain return ``NaN`` rather than
triggering a CUDA-side index assertion.
"""
self.model.eval()

# Prepare data
tokens = self._preprocess_sequences(sequences)
mz_tensor = torch.tensor(mz, dtype=torch.float32, device=self._device)
# The model's charge embedding has 4 slots (charges 1..4).
# Anything outside that range cannot be one-hot encoded without
# tripping F.one_hot's index assertion (silent on CPU, hard
# CUDA crash on GPU). Filter such rows out, predict on the
# rest, and return NaN at their positions so downstream sees
# "no prediction" instead of a process-killing assert.
n_total = len(sequences)
charges_arr = np.asarray(charges, dtype=np.int64)
valid_mask = (charges_arr >= 1) & (charges_arr <= 4)
n_invalid = int((~valid_mask).sum())
if n_invalid:
warnings.warn(
f"simulate_ion_mobilities: {n_invalid} of {n_total} input "
f"PSMs have charge outside [1, 4]; returning NaN for them.",
RuntimeWarning,
stacklevel=2,
)

valid_idx = np.flatnonzero(valid_mask)
inverse_mobility = np.full(n_total, np.nan, dtype=np.float64)
inverse_mobility_std = (
np.full(n_total, np.nan, dtype=np.float64) if return_uncertainty else None
)

if valid_idx.size == 0:
return (
(inverse_mobility, inverse_mobility_std)
if return_uncertainty
else inverse_mobility
)

valid_sequences = [sequences[i] for i in valid_idx]
valid_charges = charges_arr[valid_idx].tolist()
valid_mz = [mz[i] for i in valid_idx]

# Prepare data on the filtered subset.
tokens = self._preprocess_sequences(valid_sequences)
mz_tensor = torch.tensor(valid_mz, dtype=torch.float32, device=self._device)
charges_onehot = F.one_hot(
torch.tensor(charges, device=self._device) - 1, num_classes=4
torch.tensor(valid_charges, dtype=torch.long, device=self._device) - 1,
num_classes=4,
).float()

# Predict in batches
all_ccs = []
all_ccs_std = []
with torch.no_grad():
for i in range(0, len(sequences), batch_size):
for i in range(0, len(valid_sequences), batch_size):
batch_tokens = tokens[i:i + batch_size]
batch_mz = mz_tensor[i:i + batch_size]
batch_charges = charges_onehot[i:i + batch_size]

# Handle different model types
if hasattr(self.model, 'predict_ccs'):
# UnifiedPeptideModel
ccs, ccs_std = self.model.predict_ccs(
ccs_pred = self.model.predict_ccs(
batch_tokens,
batch_mz,
torch.argmax(batch_charges, dim=1) + 1,
)
ccs = _extract_prediction_mean(ccs_pred, "ccs")
ccs_std = _extract_prediction_std(ccs_pred, "ccs")
else:
# Legacy PyTorchCCSPredictor
result = self.model(batch_mz, batch_charges, batch_tokens)
if isinstance(result, tuple):
ccs, ccs_std = result[0], result[1] if len(result) > 1 else None
else:
ccs, ccs_std = result, None
ccs = _extract_prediction_mean(result, "ccs")
ccs_std = _extract_prediction_std(result, "ccs")

all_ccs.append(ccs.cpu().numpy())
if ccs_std is not None:
all_ccs_std.append(ccs_std.cpu().numpy())

ccs = np.concatenate(all_ccs, axis=0).flatten()

# Convert CCS to inverse mobility
inverse_mobility = np.array([
# Convert CCS to inverse mobility on the valid subset, then
# scatter back into the full output array.
valid_inverse_mobility = np.array([
ccs_to_one_over_k0(c, m, z)
for c, m, z in zip(ccs, mz, charges)
for c, m, z in zip(ccs, valid_mz, valid_charges)
])
inverse_mobility[valid_idx] = valid_inverse_mobility

if return_uncertainty and all_ccs_std:
ccs_std = np.concatenate(all_ccs_std, axis=0).flatten()
inverse_mobility_std = np.array([
valid_inverse_mobility_std = np.array([
ccs_to_one_over_k0(s, m, z)
for s, m, z in zip(ccs_std, mz, charges)
for s, m, z in zip(ccs_std, valid_mz, valid_charges)
])
inverse_mobility_std[valid_idx] = valid_inverse_mobility_std
return inverse_mobility, inverse_mobility_std

return inverse_mobility
Expand Down Expand Up @@ -546,8 +614,14 @@ def fine_tune_model(
assert 'calcmass' in data.columns, 'Data must contain column "calcmass"'
assert 'ims' in data.columns, 'Data must contain column "ims"'

mz = [calculate_mz(m, z) for m, z in zip(data.calcmass.values, data.charge.values.astype(np.int32))]
charges = data.charge.values.astype(np.int32)
mz = [
calculate_mz(m, z)
for m, z in zip(
data.calcmass.values,
data.charge.values.astype(np.int64),
)
]
charges = data.charge.values.astype(np.int64)

if decoys_separate:
sequences = []
Expand All @@ -559,7 +633,37 @@ def fine_tune_model(
else:
sequences = list(data.sequence_modified.values)

inv_mob = data.ims.values
# Drop training rows whose charge is outside the model's
# one-hot domain (1..4). They cannot be one-hot encoded without
# tripping F.one_hot's CUDA index assertion. Charges of 5+ leak
# through sage matching at ~0.15% on HeLa-class data even when
# precursor_charge is configured as [2, 4].
n_total = len(sequences)
valid_mask = (charges >= 1) & (charges <= 4)
n_invalid = int((~valid_mask).sum())
if n_invalid:
warnings.warn(
f"fine_tune_model: dropping {n_invalid} of {n_total} training "
f"PSMs with charge outside [1, 4].",
RuntimeWarning,
stacklevel=2,
)
sequences = [s for s, m in zip(sequences, valid_mask) if m]
mz = [m for m, ok in zip(mz, valid_mask) if ok]
charges = charges[valid_mask]
inv_mob = data.ims.values[valid_mask]
else:
inv_mob = data.ims.values

if len(sequences) == 0:
warnings.warn(
"fine_tune_model: no PSMs in valid charge range; skipping fine-tune.",
RuntimeWarning,
stacklevel=2,
)
self._finetune_history = {"epochs": [], "train_loss": [], "val_loss": []}
return

ccs = np.array([
one_over_k0_to_ccs(i, m, z)
for i, m, z in zip(inv_mob, mz, charges)
Expand All @@ -569,7 +673,8 @@ def fine_tune_model(
tokens = self._preprocess_sequences(sequences)
mz_tensor = torch.tensor(mz, dtype=torch.float32, device=self._device)
charges_onehot = F.one_hot(
torch.tensor(charges, device=self._device) - 1, num_classes=4
torch.tensor(charges, dtype=torch.long, device=self._device) - 1,
num_classes=4,
).float()
ccs_tensor = torch.tensor(ccs, dtype=torch.float32, device=self._device).unsqueeze(1)

Expand Down Expand Up @@ -606,6 +711,7 @@ def fine_tune_model(
)
checkpoint = InMemoryCheckpoint(patience=patience)

history = {"epochs": [], "train_loss": [], "val_loss": []}
for epoch in range(epochs):
# Training
self.model.train()
Expand All @@ -616,20 +722,20 @@ def fine_tune_model(

# Handle different model types
if hasattr(self.model, 'predict_ccs'):
pred, _ = self.model.predict_ccs(
pred = self.model.predict_ccs(
tokens_b,
mz_b,
torch.argmax(charge_b, dim=1) + 1,
)
else:
pred, _ = self.model(mz_b, charge_b, tokens_b)
if isinstance(pred, tuple):
pred = pred[0]
pred = self.model(mz_b, charge_b, tokens_b)
pred = _extract_prediction_mean(pred, "ccs")

loss = F.l1_loss(pred, ccs_b)
loss.backward()
optimizer.step()
train_loss += loss.item()
train_loss /= max(len(train_loader), 1)

# Validation
self.model.eval()
Expand All @@ -639,23 +745,26 @@ def fine_tune_model(
mz_b, charge_b, tokens_b, ccs_b = batch

if hasattr(self.model, 'predict_ccs'):
pred, _ = self.model.predict_ccs(
pred = self.model.predict_ccs(
tokens_b,
mz_b,
torch.argmax(charge_b, dim=1) + 1,
)
else:
pred, _ = self.model(mz_b, charge_b, tokens_b)
if isinstance(pred, tuple):
pred = pred[0]
pred = self.model(mz_b, charge_b, tokens_b)
pred = _extract_prediction_mean(pred, "ccs")

val_loss += F.l1_loss(pred, ccs_b).item()

val_loss /= len(val_loader)
val_loss /= max(len(val_loader), 1)
scheduler.step(val_loss)

history["epochs"].append(epoch)
history["train_loss"].append(float(train_loss))
history["val_loss"].append(float(val_loss))

if verbose and epoch % 10 == 0:
print(f"Epoch {epoch}: val_loss={val_loss:.4f}")
print(f"Epoch {epoch}: train_loss={train_loss:.4f} val_loss={val_loss:.4f}")

if checkpoint.step(val_loss, self.model):
if verbose:
Expand All @@ -664,6 +773,7 @@ def fine_tune_model(

checkpoint.restore(self.model)
self.model.eval()
self._finetune_history = history

def simulate_ion_mobilities_pandas(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from imspy_predictors.intensity.predictors import (
IonIntensityPredictor,
Prosit2023TimsTofWrapper,
calibrate_nce,
get_collision_energy_calibration_factor,
remove_unimod_annotation,
predict_fragment_intensities_with_koina,
Expand All @@ -25,6 +26,7 @@
'IonIntensityPredictor',
'Prosit2023TimsTofWrapper',
# Utilities
'calibrate_nce',
'get_collision_energy_calibration_factor',
'remove_unimod_annotation',
'post_process_predicted_fragment_spectra',
Expand Down
Loading
Loading