From e9c6c18e8612e14f6f87ddcf8733fe26dca6c6fb Mon Sep 17 00:00:00 2001 From: theGreatHerrLebert Date: Wed, 6 May 2026 17:30:42 +0200 Subject: [PATCH 1/5] Fix imspy predictor fine-tuning outputs --- .../src/imspy_predictors/ccs/predictors.py | 67 ++++++++++++++----- .../src/imspy_predictors/rt/predictors.py | 22 +++++- .../imspy-predictors/tests/test_predictors.py | 26 +++++++ 3 files changed, 94 insertions(+), 21 deletions(-) diff --git a/packages/imspy-predictors/src/imspy_predictors/ccs/predictors.py b/packages/imspy-predictors/src/imspy_predictors/ccs/predictors.py index 24ab6230..16d1dfae 100644 --- a/packages/imspy-predictors/src/imspy_predictors/ccs/predictors.py +++ b/packages/imspy-predictors/src/imspy_predictors/ccs/predictors.py @@ -88,6 +88,31 @@ def predict_inverse_ion_mobility( ps.inverse_ion_mobility_predicted = mob +def _extract_prediction_mean(pred, key: Optional[str] = None): + """Extract the mean tensor from legacy, tuple, and dict model outputs.""" + if isinstance(pred, dict): + if key is None: + if len(pred) != 1: + raise KeyError( + "Cannot infer prediction key from multi-task model output." + ) + pred = next(iter(pred.values())) + else: + pred = pred[key] + if isinstance(pred, tuple): + pred = pred[0] + return pred + + +def _extract_prediction_std(pred, key: Optional[str] = None): + """Extract an uncertainty tensor when a model output exposes one.""" + if isinstance(pred, dict): + pred = pred[key] if key is not None else next(iter(pred.values())) + if isinstance(pred, tuple) and len(pred) > 1: + return pred[1] + return None + + def get_sqrt_slopes_and_intercepts( mz: NDArray, charge: NDArray, @@ -467,7 +492,8 @@ def simulate_ion_mobilities( tokens = self._preprocess_sequences(sequences) mz_tensor = torch.tensor(mz, dtype=torch.float32, device=self._device) charges_onehot = F.one_hot( - torch.tensor(charges, device=self._device) - 1, num_classes=4 + torch.tensor(charges, dtype=torch.long, device=self._device) - 1, + num_classes=4, ).float() # Predict in batches @@ -482,18 +508,18 @@ def simulate_ion_mobilities( # Handle different model types if hasattr(self.model, 'predict_ccs'): # UnifiedPeptideModel - ccs, ccs_std = self.model.predict_ccs( + ccs_pred = self.model.predict_ccs( batch_tokens, batch_mz, torch.argmax(batch_charges, dim=1) + 1, ) + ccs = _extract_prediction_mean(ccs_pred, "ccs") + ccs_std = _extract_prediction_std(ccs_pred, "ccs") else: # Legacy PyTorchCCSPredictor result = self.model(batch_mz, batch_charges, batch_tokens) - if isinstance(result, tuple): - ccs, ccs_std = result[0], result[1] if len(result) > 1 else None - else: - ccs, ccs_std = result, None + ccs = _extract_prediction_mean(result, "ccs") + ccs_std = _extract_prediction_std(result, "ccs") all_ccs.append(ccs.cpu().numpy()) if ccs_std is not None: @@ -546,8 +572,14 @@ def fine_tune_model( assert 'calcmass' in data.columns, 'Data must contain column "calcmass"' assert 'ims' in data.columns, 'Data must contain column "ims"' - mz = [calculate_mz(m, z) for m, z in zip(data.calcmass.values, data.charge.values.astype(np.int32))] - charges = data.charge.values.astype(np.int32) + mz = [ + calculate_mz(m, z) + for m, z in zip( + data.calcmass.values, + data.charge.values.astype(np.int64), + ) + ] + charges = data.charge.values.astype(np.int64) if decoys_separate: sequences = [] @@ -569,7 +601,8 @@ def fine_tune_model( tokens = self._preprocess_sequences(sequences) mz_tensor = torch.tensor(mz, dtype=torch.float32, device=self._device) charges_onehot = F.one_hot( - torch.tensor(charges, device=self._device) - 1, num_classes=4 + torch.tensor(charges, dtype=torch.long, device=self._device) - 1, + num_classes=4, ).float() ccs_tensor = torch.tensor(ccs, dtype=torch.float32, device=self._device).unsqueeze(1) @@ -616,15 +649,14 @@ def fine_tune_model( # Handle different model types if hasattr(self.model, 'predict_ccs'): - pred, _ = self.model.predict_ccs( + pred = self.model.predict_ccs( tokens_b, mz_b, torch.argmax(charge_b, dim=1) + 1, ) else: - pred, _ = self.model(mz_b, charge_b, tokens_b) - if isinstance(pred, tuple): - pred = pred[0] + pred = self.model(mz_b, charge_b, tokens_b) + pred = _extract_prediction_mean(pred, "ccs") loss = F.l1_loss(pred, ccs_b) loss.backward() @@ -639,19 +671,18 @@ def fine_tune_model( mz_b, charge_b, tokens_b, ccs_b = batch if hasattr(self.model, 'predict_ccs'): - pred, _ = self.model.predict_ccs( + pred = self.model.predict_ccs( tokens_b, mz_b, torch.argmax(charge_b, dim=1) + 1, ) else: - pred, _ = self.model(mz_b, charge_b, tokens_b) - if isinstance(pred, tuple): - pred = pred[0] + pred = self.model(mz_b, charge_b, tokens_b) + pred = _extract_prediction_mean(pred, "ccs") val_loss += F.l1_loss(pred, ccs_b).item() - val_loss /= len(val_loader) + val_loss /= max(len(val_loader), 1) scheduler.step(val_loss) if verbose and epoch % 10 == 0: diff --git a/packages/imspy-predictors/src/imspy_predictors/rt/predictors.py b/packages/imspy-predictors/src/imspy_predictors/rt/predictors.py index b5f09a2b..9f255716 100644 --- a/packages/imspy-predictors/src/imspy_predictors/rt/predictors.py +++ b/packages/imspy-predictors/src/imspy_predictors/rt/predictors.py @@ -96,6 +96,22 @@ def predict_retention_time( ps.retention_time_predicted = rt +def _extract_prediction_mean(pred, key: Optional[str] = None): + """Extract the mean tensor from legacy, tuple, and dict model outputs.""" + if isinstance(pred, dict): + if key is None: + if len(pred) != 1: + raise KeyError( + "Cannot infer prediction key from multi-task model output." + ) + pred = next(iter(pred.values())) + else: + pred = pred[key] + if isinstance(pred, tuple): + pred = pred[0] + return pred + + class PeptideChromatographyApex(ABC): """Abstract base class for chromatographic separation prediction.""" @@ -421,7 +437,7 @@ def _fine_tune( self.model.train() for tokens_b, rt_b in train_loader: optimizer.zero_grad() - pred = self.model(tokens_b) + pred = _extract_prediction_mean(self.model(tokens_b), "rt") loss = F.l1_loss(pred, rt_b) loss.backward() optimizer.step() @@ -430,9 +446,9 @@ def _fine_tune( val_loss = 0 with torch.no_grad(): for tokens_b, rt_b in val_loader: - pred = self.model(tokens_b) + pred = _extract_prediction_mean(self.model(tokens_b), "rt") val_loss += F.l1_loss(pred, rt_b).item() - val_loss /= len(val_loader) + val_loss /= max(len(val_loader), 1) scheduler.step(val_loss) if verbose and epoch % 10 == 0: diff --git a/packages/imspy-predictors/tests/test_predictors.py b/packages/imspy-predictors/tests/test_predictors.py index bb1a25e5..df08538a 100644 --- a/packages/imspy-predictors/tests/test_predictors.py +++ b/packages/imspy-predictors/tests/test_predictors.py @@ -51,6 +51,22 @@ def test_sqrt_slopes_intercepts(self): assert len(slopes) == 4 # Charges 1-4 (with 0 for charge 1) assert len(intercepts) == 4 + def test_ccs_extract_prediction_mean(self): + """Test CCS fine-tune helper handles dict and tuple outputs.""" + from imspy_predictors.ccs.predictors import ( + _extract_prediction_mean, + _extract_prediction_std, + ) + + mean = torch.ones(3, 1) + std = torch.full((3, 1), 0.1) + + assert _extract_prediction_mean((mean, std)) is mean + assert _extract_prediction_mean({"ccs": (mean, std)}, "ccs") is mean + assert _extract_prediction_mean({"ccs": mean}, "ccs") is mean + assert _extract_prediction_std((mean, std)) is std + assert _extract_prediction_std({"ccs": (mean, std)}, "ccs") is std + class TestRTPredictor: """Test suite for retention time predictors.""" @@ -75,6 +91,16 @@ def test_pytorch_rt_forward(self, rt_model): assert rt.shape == (batch_size, 1) + def test_rt_extract_prediction_mean(self): + """Test RT fine-tune helper handles dict and tuple outputs.""" + from imspy_predictors.rt.predictors import _extract_prediction_mean + + mean = torch.ones(3, 1) + + assert _extract_prediction_mean(mean, "rt") is mean + assert _extract_prediction_mean((mean,), "rt") is mean + assert _extract_prediction_mean({"rt": mean}, "rt") is mean + class TestChargePredictor: """Test suite for charge state predictors.""" From f7ac88051ec000d53f76d6e8e5b0c4b6f1a47232 Mon Sep 17 00:00:00 2001 From: theGreatHerrLebert Date: Wed, 6 May 2026 17:50:30 +0200 Subject: [PATCH 2/5] Add native intensity fine-tuning --- .../imspy_predictors/intensity/predictors.py | 228 ++++++++++++++++++ .../imspy-predictors/tests/test_predictors.py | 23 ++ 2 files changed, 251 insertions(+) diff --git a/packages/imspy-predictors/src/imspy_predictors/intensity/predictors.py b/packages/imspy-predictors/src/imspy_predictors/intensity/predictors.py index e422f1a7..a4d84493 100644 --- a/packages/imspy-predictors/src/imspy_predictors/intensity/predictors.py +++ b/packages/imspy-predictors/src/imspy_predictors/intensity/predictors.py @@ -26,6 +26,8 @@ from imspy_core.data import PeptideProductIonSeriesCollection, PeptideSequence from imspy_core.utility import remove_unimod_annotation +from imspy_predictors.losses import masked_spectral_distance +from imspy_predictors.utility import InMemoryCheckpoint # Lazy imports for optional dependencies from imspy_predictors.lazy_imports import ( @@ -34,6 +36,69 @@ ) +def _ion_to_text(ion) -> str: + text = str(ion).lower() + if text in ("b", "y"): + return text + if text in ("iontype(b)", "iontype(y)"): + return text[-2] + if text.endswith(" b") or text.endswith(".b"): + return "b" + if text.endswith(" y") or text.endswith(".y"): + return "y" + return text[-1:] + + +def observed_fragments_to_intensity_target( + sequence: str, + precursor_charge: int, + fragments, +) -> np.ndarray: + """Build a native Prosit-layout target vector from observed Sage fragments. + + Output layout is ordinal-major: + [y1+1, y1+2, y1+3, b1+1, b1+2, b1+3, y2+1, ...]. + Impossible fragments are marked -1 for masked spectral loss; valid but + unmatched fragments remain zero. + """ + sequence_length = len(remove_unimod_annotation(sequence)) + target = np.zeros(174, dtype=np.float32) + + target_3d = target.reshape(29, 6) + max_frag_pos = max(sequence_length - 1, 0) + if max_frag_pos < 29: + target_3d[max_frag_pos:, :] = -1.0 + for ion_charge in range(1, 4): + if ion_charge > int(precursor_charge): + target_3d[:, ion_charge - 1] = -1.0 + target_3d[:, ion_charge + 2] = -1.0 + + intensities = np.asarray(fragments.intensities, dtype=np.float32) + if intensities.size == 0: + return target + max_intensity = float(np.max(intensities)) + if max_intensity <= 0: + return target + + for ion_type, ordinal, charge, intensity in zip( + fragments.ion_types, + fragments.fragment_ordinals, + fragments.charges, + intensities, + ): + ion = _ion_to_text(ion_type) + ordinal = int(ordinal) + charge = int(charge) + if ion not in ("b", "y"): + continue + if not (1 <= ordinal <= 29 and 1 <= charge <= 3): + continue + slot = (ordinal - 1) * 6 + (charge - 1 if ion == "y" else 3 + charge - 1) + if target[slot] >= 0: + target[slot] = float(intensity) / max_intensity + return target + + def predict_intensities_prosit( psm_collection: List, calibrate_collision_energy: bool = True, @@ -757,6 +822,169 @@ def _predict_batch( return np.vstack(all_intensities) + def fine_tune_model( + self, + data: pd.DataFrame, + batch_size: int = 64, + epochs: int = 50, + learning_rate: float = 1e-4, + patience: int = 5, + divide_collision_energy_by: float = 1e2, + verbose: bool = False, + ) -> None: + """ + Fine-tune the native intensity model on observed 174-vector targets. + + Args: + data: DataFrame with columns: sequence, charge, collision_energy, + intensity_target. The target must use the native ordinal-major + 174-vector layout and mark impossible ions as -1. + batch_size: Training batch size + epochs: Maximum number of epochs + learning_rate: Learning rate + patience: Early stopping patience + divide_collision_energy_by: CE normalization factor + verbose: Whether to print progress + """ + assert 'sequence' in data.columns, 'Data must contain column "sequence"' + assert 'charge' in data.columns, 'Data must contain column "charge"' + assert 'collision_energy' in data.columns, 'Data must contain column "collision_energy"' + assert 'intensity_target' in data.columns, 'Data must contain column "intensity_target"' + + from torch.utils.data import DataLoader, TensorDataset + + if len(data) < 2: + if verbose: + print("Skipping intensity fine-tune: need at least two PSMs") + return + + sequences = data.sequence.tolist() + charges = data.charge.astype(np.int64).tolist() + collision_energies = (data.collision_energy.astype(float) / divide_collision_energy_by).tolist() + targets = np.vstack(data.intensity_target.to_numpy()).astype(np.float32) + if targets.shape != (len(data), 174): + raise ValueError(f"intensity_target must have shape (n, 174), got {targets.shape}") + + tokens, charge_tensor, ce_tensor = self._preprocess( + sequences, + charges, + collision_energies, + ) + tokens = tokens.to(self._device) + charge_tensor = charge_tensor.to(self._device) + ce_tensor = ce_tensor.to(self._device) + target_tensor = self._torch.tensor(targets, dtype=self._torch.float32, device=self._device) + + n = len(sequences) + n_train = max(1, int(0.8 * n)) + if n_train >= n: + n_train = n - 1 + indices = self._torch.randperm(n, device=self._device) + train_idx = indices[:n_train] + val_idx = indices[n_train:] + + train_dataset = TensorDataset( + tokens[train_idx], + charge_tensor[train_idx], + ce_tensor[train_idx], + target_tensor[train_idx], + ) + val_dataset = TensorDataset( + tokens[val_idx], + charge_tensor[val_idx], + ce_tensor[val_idx], + target_tensor[val_idx], + ) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size) + + self.model.train() + optimizer = self._torch.optim.Adam(self.model.parameters(), lr=learning_rate) + scheduler = self._torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, patience=3, min_lr=1e-6 + ) + checkpoint = InMemoryCheckpoint(patience=patience) + + for epoch in range(epochs): + self.model.train() + for tokens_b, charge_b, ce_b, target_b in train_loader: + optimizer.zero_grad() + outputs = self.model( + tokens_b, + charge=charge_b, + collision_energy=ce_b, + ) + pred = outputs['intensity'] if 'intensity' in outputs else list(outputs.values())[0] + loss = masked_spectral_distance(target_b, pred) + loss.backward() + optimizer.step() + + self.model.eval() + val_loss = 0.0 + with self._torch.no_grad(): + for tokens_b, charge_b, ce_b, target_b in val_loader: + outputs = self.model( + tokens_b, + charge=charge_b, + collision_energy=ce_b, + ) + pred = outputs['intensity'] if 'intensity' in outputs else list(outputs.values())[0] + val_loss += masked_spectral_distance(target_b, pred).item() + val_loss /= max(len(val_loader), 1) + scheduler.step(val_loss) + + if verbose and epoch % 5 == 0: + print(f"Epoch {epoch}: intensity_val_loss={val_loss:.4f}") + + if checkpoint.step(val_loss, self.model): + if verbose: + print(f"Early stopping intensity fine-tune at epoch {epoch}") + break + + checkpoint.restore(self.model) + self.model.eval() + + def fine_tune_psms( + self, + psm_collection: List, + batch_size: int = 64, + epochs: int = 50, + learning_rate: float = 1e-4, + patience: int = 5, + verbose: bool = False, + ) -> None: + """Fine-tune the native intensity model from Sage PSM observed fragments.""" + rows = [] + for psm in psm_collection: + sequence = psm.sequence_modified if not psm.decoy else psm.sequence_decoy_modified + target = observed_fragments_to_intensity_target( + sequence, + psm.charge, + psm.sage_feature.fragments, + ) + if np.any(target > 0): + rows.append({ + 'sequence': sequence, + 'charge': int(psm.charge), + 'collision_energy': float(psm.collision_energy), + 'intensity_target': target, + }) + + if len(rows) < 2: + if verbose: + print("Skipping intensity fine-tune: no usable fragment targets") + return + + self.fine_tune_model( + pd.DataFrame(rows), + batch_size=batch_size, + epochs=epochs, + learning_rate=learning_rate, + patience=patience, + verbose=verbose, + ) + def predict_intensities( self, sequences: List[str], diff --git a/packages/imspy-predictors/tests/test_predictors.py b/packages/imspy-predictors/tests/test_predictors.py index df08538a..15d24e24 100644 --- a/packages/imspy-predictors/tests/test_predictors.py +++ b/packages/imspy-predictors/tests/test_predictors.py @@ -128,6 +128,29 @@ def test_pytorch_charge_forward(self, charge_model): assert torch.allclose(probs.sum(dim=1), torch.ones(batch_size), atol=1e-5) +class TestIntensityPredictor: + """Test suite for native intensity helpers.""" + + def test_observed_fragments_to_intensity_target(self): + from imspy_predictors.intensity.predictors import ( + observed_fragments_to_intensity_target, + ) + + class Fragments: + ion_types = ["IonType(Y)", "IonType(B)", "IonType(Y)"] + fragment_ordinals = [1, 2, 4] + charges = [1, 2, 3] + intensities = [10.0, 5.0, 1.0] + + target = observed_fragments_to_intensity_target("PEPTIDE", 2, Fragments()) + + assert target.shape == (174,) + assert target[0] == 1.0 + assert target[(2 - 1) * 6 + 4] == 0.5 + assert target[(4 - 1) * 6 + 2] == -1.0 + assert np.all(target[(len("PEPTIDE") - 1) * 6:] == -1.0) + + class TestBinomialChargeModel: """Test suite for binomial charge state model.""" From 86857fba25dc0a7cf8a3deae70df5902f692f1fc Mon Sep 17 00:00:00 2001 From: theGreatHerrLebert Date: Wed, 6 May 2026 18:22:15 +0200 Subject: [PATCH 3/5] Filter out-of-range charges in IM simulate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit simulate_ion_mobilities builds a charge one-hot with num_classes=4 (charges 1..4). Inputs outside that range trigger F.one_hot's index assertion — silent on CPU, hard CUDA crash on GPU (ScatterGatherKernel idx_dim < index_size). Filter such rows before the one_hot, predict on the valid subset, and NaN-pad invalid positions in the output array. Emit a RuntimeWarning so callers see how many were skipped. Charges of 5+ leak through sage matching at ~0.15% on HeLa data even with precursor_charge=[2,4], which is enough to take down a GPU run. --- .../src/imspy_predictors/ccs/predictors.py | 64 +++++++++++++++---- 1 file changed, 53 insertions(+), 11 deletions(-) diff --git a/packages/imspy-predictors/src/imspy_predictors/ccs/predictors.py b/packages/imspy-predictors/src/imspy_predictors/ccs/predictors.py index 16d1dfae..0683a6c0 100644 --- a/packages/imspy-predictors/src/imspy_predictors/ccs/predictors.py +++ b/packages/imspy-predictors/src/imspy_predictors/ccs/predictors.py @@ -9,6 +9,7 @@ - PyTorchCCSPredictor: PyTorch transformer-based model """ +import warnings from typing import List, Tuple, Optional, Union from abc import ABC, abstractmethod @@ -484,15 +485,53 @@ def simulate_ion_mobilities( return_uncertainty: If True, also return predicted uncertainty (std) Returns: - Inverse ion mobilities (1/K0), or tuple of (1/K0, std) if return_uncertainty=True + Inverse ion mobilities (1/K0), or tuple of (1/K0, std) if + return_uncertainty=True. Charges outside the model's + ``[1, _CHARGE_MAX]`` domain return ``NaN`` rather than + triggering a CUDA-side index assertion. """ self.model.eval() - # Prepare data - tokens = self._preprocess_sequences(sequences) - mz_tensor = torch.tensor(mz, dtype=torch.float32, device=self._device) + # The model's charge embedding has 4 slots (charges 1..4). + # Anything outside that range cannot be one-hot encoded without + # tripping F.one_hot's index assertion (silent on CPU, hard + # CUDA crash on GPU). Filter such rows out, predict on the + # rest, and return NaN at their positions so downstream sees + # "no prediction" instead of a process-killing assert. + n_total = len(sequences) + charges_arr = np.asarray(charges, dtype=np.int64) + valid_mask = (charges_arr >= 1) & (charges_arr <= 4) + n_invalid = int((~valid_mask).sum()) + if n_invalid: + warnings.warn( + f"simulate_ion_mobilities: {n_invalid} of {n_total} input " + f"PSMs have charge outside [1, 4]; returning NaN for them.", + RuntimeWarning, + stacklevel=2, + ) + + valid_idx = np.flatnonzero(valid_mask) + inverse_mobility = np.full(n_total, np.nan, dtype=np.float64) + inverse_mobility_std = ( + np.full(n_total, np.nan, dtype=np.float64) if return_uncertainty else None + ) + + if valid_idx.size == 0: + return ( + (inverse_mobility, inverse_mobility_std) + if return_uncertainty + else inverse_mobility + ) + + valid_sequences = [sequences[i] for i in valid_idx] + valid_charges = charges_arr[valid_idx].tolist() + valid_mz = [mz[i] for i in valid_idx] + + # Prepare data on the filtered subset. + tokens = self._preprocess_sequences(valid_sequences) + mz_tensor = torch.tensor(valid_mz, dtype=torch.float32, device=self._device) charges_onehot = F.one_hot( - torch.tensor(charges, dtype=torch.long, device=self._device) - 1, + torch.tensor(valid_charges, dtype=torch.long, device=self._device) - 1, num_classes=4, ).float() @@ -500,7 +539,7 @@ def simulate_ion_mobilities( all_ccs = [] all_ccs_std = [] with torch.no_grad(): - for i in range(0, len(sequences), batch_size): + for i in range(0, len(valid_sequences), batch_size): batch_tokens = tokens[i:i + batch_size] batch_mz = mz_tensor[i:i + batch_size] batch_charges = charges_onehot[i:i + batch_size] @@ -527,18 +566,21 @@ def simulate_ion_mobilities( ccs = np.concatenate(all_ccs, axis=0).flatten() - # Convert CCS to inverse mobility - inverse_mobility = np.array([ + # Convert CCS to inverse mobility on the valid subset, then + # scatter back into the full output array. + valid_inverse_mobility = np.array([ ccs_to_one_over_k0(c, m, z) - for c, m, z in zip(ccs, mz, charges) + for c, m, z in zip(ccs, valid_mz, valid_charges) ]) + inverse_mobility[valid_idx] = valid_inverse_mobility if return_uncertainty and all_ccs_std: ccs_std = np.concatenate(all_ccs_std, axis=0).flatten() - inverse_mobility_std = np.array([ + valid_inverse_mobility_std = np.array([ ccs_to_one_over_k0(s, m, z) - for s, m, z in zip(ccs_std, mz, charges) + for s, m, z in zip(ccs_std, valid_mz, valid_charges) ]) + inverse_mobility_std[valid_idx] = valid_inverse_mobility_std return inverse_mobility, inverse_mobility_std return inverse_mobility From 9981465cd4bbf17e847d0337a0f9538d13d48bef Mon Sep 17 00:00:00 2001 From: theGreatHerrLebert Date: Wed, 6 May 2026 18:29:51 +0200 Subject: [PATCH 4/5] Capture per-epoch fine-tune history; harden IM fine-tune charge filter Three small additions to the fine-tune training loops in imspy_predictors: - rt/predictors.py: also accumulate train_loss across train batches, store {epochs, train_loss, val_loss} on self._finetune_history. - ccs/predictors.py: same for CCS/IM fine_tune_model. Also drop training rows whose charge is outside the model's [1, 4] one-hot domain before constructing the charge tensor; same root cause as the earlier simulate_ion_mobilities filter (CUDA assertion on charge=5+ PSMs that leak through sage matching). - intensity/predictors.py: same train_loss accumulation and history capture for the native intensity fine-tune loop. The history dict is the shape sagepy-rescore's report.py expects so it can render per-head loss curves and improvement-vs-epoch-0 panels for the sagepy-rescore HTML report. --- .../src/imspy_predictors/ccs/predictors.py | 41 ++++++++++++++++++- .../imspy_predictors/intensity/predictors.py | 14 ++++++- .../src/imspy_predictors/rt/predictors.py | 12 +++++- 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/packages/imspy-predictors/src/imspy_predictors/ccs/predictors.py b/packages/imspy-predictors/src/imspy_predictors/ccs/predictors.py index 0683a6c0..b59dff90 100644 --- a/packages/imspy-predictors/src/imspy_predictors/ccs/predictors.py +++ b/packages/imspy-predictors/src/imspy_predictors/ccs/predictors.py @@ -633,7 +633,37 @@ def fine_tune_model( else: sequences = list(data.sequence_modified.values) - inv_mob = data.ims.values + # Drop training rows whose charge is outside the model's + # one-hot domain (1..4). They cannot be one-hot encoded without + # tripping F.one_hot's CUDA index assertion. Charges of 5+ leak + # through sage matching at ~0.15% on HeLa-class data even when + # precursor_charge is configured as [2, 4]. + n_total = len(sequences) + valid_mask = (charges >= 1) & (charges <= 4) + n_invalid = int((~valid_mask).sum()) + if n_invalid: + warnings.warn( + f"fine_tune_model: dropping {n_invalid} of {n_total} training " + f"PSMs with charge outside [1, 4].", + RuntimeWarning, + stacklevel=2, + ) + sequences = [s for s, m in zip(sequences, valid_mask) if m] + mz = [m for m, ok in zip(mz, valid_mask) if ok] + charges = charges[valid_mask] + inv_mob = data.ims.values[valid_mask] + else: + inv_mob = data.ims.values + + if len(sequences) == 0: + warnings.warn( + "fine_tune_model: no PSMs in valid charge range; skipping fine-tune.", + RuntimeWarning, + stacklevel=2, + ) + self._finetune_history = {"epochs": [], "train_loss": [], "val_loss": []} + return + ccs = np.array([ one_over_k0_to_ccs(i, m, z) for i, m, z in zip(inv_mob, mz, charges) @@ -681,6 +711,7 @@ def fine_tune_model( ) checkpoint = InMemoryCheckpoint(patience=patience) + history = {"epochs": [], "train_loss": [], "val_loss": []} for epoch in range(epochs): # Training self.model.train() @@ -704,6 +735,7 @@ def fine_tune_model( loss.backward() optimizer.step() train_loss += loss.item() + train_loss /= max(len(train_loader), 1) # Validation self.model.eval() @@ -727,8 +759,12 @@ def fine_tune_model( val_loss /= max(len(val_loader), 1) scheduler.step(val_loss) + history["epochs"].append(epoch) + history["train_loss"].append(float(train_loss)) + history["val_loss"].append(float(val_loss)) + if verbose and epoch % 10 == 0: - print(f"Epoch {epoch}: val_loss={val_loss:.4f}") + print(f"Epoch {epoch}: train_loss={train_loss:.4f} val_loss={val_loss:.4f}") if checkpoint.step(val_loss, self.model): if verbose: @@ -737,6 +773,7 @@ def fine_tune_model( checkpoint.restore(self.model) self.model.eval() + self._finetune_history = history def simulate_ion_mobilities_pandas( self, diff --git a/packages/imspy-predictors/src/imspy_predictors/intensity/predictors.py b/packages/imspy-predictors/src/imspy_predictors/intensity/predictors.py index a4d84493..81e76081 100644 --- a/packages/imspy-predictors/src/imspy_predictors/intensity/predictors.py +++ b/packages/imspy-predictors/src/imspy_predictors/intensity/predictors.py @@ -906,8 +906,10 @@ def fine_tune_model( ) checkpoint = InMemoryCheckpoint(patience=patience) + history = {"epochs": [], "train_loss": [], "val_loss": []} for epoch in range(epochs): self.model.train() + train_loss = 0.0 for tokens_b, charge_b, ce_b, target_b in train_loader: optimizer.zero_grad() outputs = self.model( @@ -919,6 +921,8 @@ def fine_tune_model( loss = masked_spectral_distance(target_b, pred) loss.backward() optimizer.step() + train_loss += loss.item() + train_loss /= max(len(train_loader), 1) self.model.eval() val_loss = 0.0 @@ -934,8 +938,15 @@ def fine_tune_model( val_loss /= max(len(val_loader), 1) scheduler.step(val_loss) + history["epochs"].append(epoch) + history["train_loss"].append(float(train_loss)) + history["val_loss"].append(float(val_loss)) + if verbose and epoch % 5 == 0: - print(f"Epoch {epoch}: intensity_val_loss={val_loss:.4f}") + print( + f"Epoch {epoch}: intensity train_loss={train_loss:.4f} " + f"val_loss={val_loss:.4f}" + ) if checkpoint.step(val_loss, self.model): if verbose: @@ -944,6 +955,7 @@ def fine_tune_model( checkpoint.restore(self.model) self.model.eval() + self._finetune_history = history def fine_tune_psms( self, diff --git a/packages/imspy-predictors/src/imspy_predictors/rt/predictors.py b/packages/imspy-predictors/src/imspy_predictors/rt/predictors.py index 9f255716..ebaaf8f2 100644 --- a/packages/imspy-predictors/src/imspy_predictors/rt/predictors.py +++ b/packages/imspy-predictors/src/imspy_predictors/rt/predictors.py @@ -433,14 +433,18 @@ def _fine_tune( scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, min_lr=1e-6) checkpoint = InMemoryCheckpoint(patience=patience) + history = {"epochs": [], "train_loss": [], "val_loss": []} for epoch in range(epochs): self.model.train() + train_loss = 0.0 for tokens_b, rt_b in train_loader: optimizer.zero_grad() pred = _extract_prediction_mean(self.model(tokens_b), "rt") loss = F.l1_loss(pred, rt_b) loss.backward() optimizer.step() + train_loss += loss.item() + train_loss /= max(len(train_loader), 1) self.model.eval() val_loss = 0 @@ -451,14 +455,20 @@ def _fine_tune( val_loss /= max(len(val_loader), 1) scheduler.step(val_loss) + history["epochs"].append(epoch) + history["train_loss"].append(float(train_loss)) + history["val_loss"].append(float(val_loss)) + if verbose and epoch % 10 == 0: - print(f"Epoch {epoch}: val_loss={val_loss:.4f}") + print(f"Epoch {epoch}: train_loss={train_loss:.4f} val_loss={val_loss:.4f}") if checkpoint.step(val_loss, self.model): if verbose: print(f"Early stopping at epoch {epoch}") break + self._finetune_history = history + checkpoint.restore(self.model) self.model.eval() From 3302d070b3bc71fcc7d3ecfa9bc914c7a318afb5 Mon Sep 17 00:00:00 2001 From: theGreatHerrLebert Date: Mon, 18 May 2026 17:07:13 +0200 Subject: [PATCH 5/5] imspy-predictors: replace CE-offset calibration with absolute-NCE sweep Add calibrate_nce(): sweeps absolute NCE over high-confidence target PSMs and returns the value maximizing mean spectral angle -- one NCE per run. The model conditions on a per-run NCE scalar (fine-tuned on collision_energy_aligned_normed, domain ~7-43), so calibration must be an absolute sweep, not an offset added to the observed collision energy. - predict_intensities_prosit: use calibrate_nce; set collision_energy_calibrated to the absolute best NCE instead of collision_energy + offset. - get_collision_energy_calibration_factor: kept as a deprecated compat wrapper over calibrate_nce. Fixes the bug where it calibrated on the unmodified sequence while the real prediction used sequence_modified. --- .../imspy_predictors/intensity/__init__.py | 2 + .../imspy_predictors/intensity/predictors.py | 152 ++++++++++++------ 2 files changed, 103 insertions(+), 51 deletions(-) diff --git a/packages/imspy-predictors/src/imspy_predictors/intensity/__init__.py b/packages/imspy-predictors/src/imspy_predictors/intensity/__init__.py index 738675c0..0ca33f94 100644 --- a/packages/imspy-predictors/src/imspy_predictors/intensity/__init__.py +++ b/packages/imspy-predictors/src/imspy_predictors/intensity/__init__.py @@ -3,6 +3,7 @@ from imspy_predictors.intensity.predictors import ( IonIntensityPredictor, Prosit2023TimsTofWrapper, + calibrate_nce, get_collision_energy_calibration_factor, remove_unimod_annotation, predict_fragment_intensities_with_koina, @@ -25,6 +26,7 @@ 'IonIntensityPredictor', 'Prosit2023TimsTofWrapper', # Utilities + 'calibrate_nce', 'get_collision_energy_calibration_factor', 'remove_unimod_annotation', 'post_process_predicted_fragment_spectra', diff --git a/packages/imspy-predictors/src/imspy_predictors/intensity/predictors.py b/packages/imspy-predictors/src/imspy_predictors/intensity/predictors.py index 81e76081..ab6cb0be 100644 --- a/packages/imspy-predictors/src/imspy_predictors/intensity/predictors.py +++ b/packages/imspy-predictors/src/imspy_predictors/intensity/predictors.py @@ -128,21 +128,17 @@ def predict_intensities_prosit( # the intensity predictor model prosit_model = Prosit2023TimsTofWrapper(verbose=False) - # sample for collision energy calibration - sample = list(sorted(psm_collection, key=lambda x: x.hyperscore, reverse=True))[:int(2 ** 11)] - + # Calibrate one absolute NCE for the run (calibrate_nce drops decoys and + # caps the sample internally). The model conditions on a per-run NCE, so + # every PSM is predicted at that single value -- not observed CE + offset. if calibrate_collision_energy: - collision_energy_calibration_factor, _ = get_collision_energy_calibration_factor( - list(filter(lambda match: match.decoy is not True, sample)), - prosit_model, - verbose=verbose - ) - + calibration = calibrate_nce(prosit_model, psm_collection, verbose=verbose) + calibrated_nce = float(calibration["best_nce"]) + for ps in psm_collection: + ps.collision_energy_calibrated = calibrated_nce else: - collision_energy_calibration_factor = 0.0 - - for ps in psm_collection: - ps.collision_energy_calibrated = ps.collision_energy + collision_energy_calibration_factor + for ps in psm_collection: + ps.collision_energy_calibrated = ps.collision_energy intensity_pred = prosit_model.predict_intensities( [p.sequence_modified for p in psm_collection], @@ -162,57 +158,111 @@ def predict_intensities_prosit( psm.prosit_predicted_intensities = psm_intensity.prosit_predicted_intensities -def get_collision_energy_calibration_factor( - sample: List, - model: 'Prosit2023TimsTofWrapper', - lower: int = -30, - upper: int = 30, +def calibrate_nce( + model, + psms: List, + nce_grid: Optional[List[int]] = None, + per_charge: bool = False, + max_sample: int = 2048, verbose: bool = False, -) -> Tuple[float, List[float]]: - """ - Get the collision energy calibration factor for a given sample. +) -> dict: + """Calibrate the absolute normalized collision energy (NCE) for a run. - Note: This function requires sagepy (via imspy-search package). + The intensity model conditions on a per-run NCE scalar -- it was fine-tuned + on ``collision_energy_aligned_normed`` (domain ~7-43). Calibration sweeps + absolute NCE values, predicts every PSM at each, and returns the value that + maximizes the mean predicted-vs-observed spectral angle. + + This is an *absolute* sweep, NOT an offset on the observed collision energy: + the observed CE (e.g. the Bruker mobility-ramped value) is a different + physical quantity and must not be added to. One NCE is returned per run. + + Note: This function requires sagepy (via the imspy-search package). Args: - sample: a list of PeptideSpectrumMatch objects (sagepy Psm objects) - model: a Prosit2023TimsTofWrapper object - lower: lower bound for the search - upper: upper bound for the search - verbose: whether to print progress + model: an intensity predictor exposing ``predict_intensities(sequences, + charges, collision_energies, batch_size=, flatten=True)`` -- e.g. + Prosit2023TimsTofWrapper or DeepPeptideIntensityPredictor. + psms: sagepy Psm objects carrying observed fragments. Decoys are dropped. + nce_grid: absolute NCE values to sweep (default ``range(15, 51)``). + per_charge: also report the best NCE separately per precursor charge. + max_sample: cap the calibration sample; if exceeded, the highest-scoring + PSMs (by hyperscore) are kept. + verbose: whether to print progress. Returns: - Tuple[float, List[float]]: the collision energy calibration factor and the angle similarities + dict: ``{best_nce, curve: [(nce, mean_spectral_angle), ...], n_psms}``; + also ``per_charge: {charge: best_nce}`` when ``per_charge`` is True. """ - associate_fragment_ions_with_prosit_predicted_intensities, Psm = get_sagepy_fragment_utils() - - cos_target, cos_decoy = [], [] + associate_fragment_ions_with_prosit_predicted_intensities, _ = get_sagepy_fragment_utils() + + if nce_grid is None: + nce_grid = list(range(15, 51)) + nce_grid = [int(x) for x in nce_grid] + + targets = [p for p in psms if not getattr(p, "decoy", False)] + if not targets: + raise ValueError("calibrate_nce: no target PSMs to calibrate on") + if max_sample and len(targets) > max_sample: + targets = sorted(targets, key=lambda p: getattr(p, "hyperscore", 0.0), + reverse=True)[:max_sample] + + def _sweep(sample): + # sequence_modified (not sequence) -- the prediction is mod-aware. + seqs = [p.sequence_modified for p in sample] + chgs = np.array([p.charge for p in sample]) + curve = [] + for nce in tqdm(nce_grid, disable=not verbose, desc="calibrating NCE", ncols=100): + intensities = model.predict_intensities( + seqs, chgs, [float(nce)] * len(sample), + batch_size=2048, flatten=True, + ) + scored = associate_fragment_ions_with_prosit_predicted_intensities( + sample, intensities + ) + sa = float(np.mean([x.spectral_angle_similarity for x in scored])) + curve.append((int(nce), sa)) + best = curve[int(np.argmax([s for _, s in curve]))][0] + return int(best), curve + + best_nce, curve = _sweep(targets) + result = {"best_nce": best_nce, "curve": curve, "n_psms": len(targets)} + + if per_charge: + pc = {} + for z in sorted({int(p.charge) for p in targets}): + sub = [p for p in targets if int(p.charge) == z] + if len(sub) < 100: + continue + pc[z], _ = _sweep(sub) + result["per_charge"] = pc if verbose: - print(f"Searching for collision energy calibration factor between {lower} and {upper} ...") - - for i in tqdm(range(lower, upper), disable=not verbose, desc='calibrating CE', ncols=100): - I = model.predict_intensities( - [p.sequence for p in sample], - np.array([p.charge for p in sample]), - [p.collision_energy + i for p in sample], - batch_size=2048, - flatten=True - ) + print(f"calibrate_nce: best NCE = {best_nce} " + f"(mean spectral angle {max(s for _, s in curve):.4f}, " + f"n = {len(targets)})") + return result - psm_i = associate_fragment_ions_with_prosit_predicted_intensities(sample, I) - target = list(filter(lambda x: not x.decoy, psm_i)) - decoy = list(filter(lambda x: x.decoy, psm_i)) - cos_target.append((i, np.mean([x.spectral_angle_similarity for x in target]))) - cos_decoy.append((i, np.mean([x.spectral_angle_similarity for x in decoy]))) - - calibration_factor, similarities = cos_target[np.argmax([x[1] for x in cos_target])][0], [x[1] for x in cos_target] +def get_collision_energy_calibration_factor( + sample: List, + model: 'Prosit2023TimsTofWrapper', + verbose: bool = False, +) -> Tuple[float, List[float]]: + """DEPRECATED -- use :func:`calibrate_nce`. - if verbose: - print(f"Best calibration factor: {calibration_factor}, with similarity: {np.max(np.round(similarities, 2))}") + BEHAVIOR CHANGED. This previously returned an *offset* added to each PSM's + ``collision_energy`` and -- a bug -- calibrated on the unmodified sequence. + It now delegates to :func:`calibrate_nce` and returns the **absolute** best + NCE. Callers must set ``collision_energy_calibrated = best_nce`` directly, + NOT ``collision_energy + factor``. - return calibration_factor, similarities + Returns: + Tuple[float, List[float]]: the absolute best NCE and the per-NCE mean + spectral angles. + """ + calibration = calibrate_nce(model, sample, verbose=verbose) + return float(calibration["best_nce"]), [sa for _, sa in calibration["curve"]] class IonIntensityPredictor(ABC):