diff --git a/docs/Combinatorial_Model_Memory_Issue.md b/docs/Combinatorial_Model_Memory_Issue.md index 3a30e66..3d7a3d5 100644 --- a/docs/Combinatorial_Model_Memory_Issue.md +++ b/docs/Combinatorial_Model_Memory_Issue.md @@ -279,3 +279,24 @@ Implementing the above fixes should dramatically reduce memory usage: [\[4\]](https://raw.githubusercontent.com/bibymaths/phoskintime/global/networkmodel/backend.py#:~:text=) raw.githubusercontent.com + +## 11. Implemented Memory-Safe Behavior + +The combinatorial implementation now keeps the same biological equations, objective values, ranking semantics, CLI options, input formats and output tables, while reducing unnecessary materialization: + +- `MODEL=2` still has an inherent `2^n_sites` state space. `Index` now validates each protein after computing `n_states = 1 << n_sites` and raises an informative `MemoryError` before simulation when the per-protein or total state dimension is unsafe. The error explains the exponential scaling and recommends reducing sites, choosing sequential/distributive models, or lowering workload. +- Combinatorial transition order is unchanged: transitions are enumerated by state mask first and site index second, yielding only unset-bit phosphorylation edges. Large proteins no longer require dense `trans_from`, `trans_to` and `trans_site` arrays for RHS evaluation; dense arrays are retained only for small compatibility cases. +- Combinatorial phosphosite signal extraction no longer builds the dense `ns x n_sites` bit matrix. Each site is aggregated with the equivalent bit-mask vector and temporaries are released immediately after use. This changes only memory materialization, not returned DataFrame columns or values. +- `S_cache` for combinatorial mode is now a reusable current-site-rate buffer rather than an unconditional site-by-time dense matrix. The RHS receives the same current rates used by the previous time-bucketed calculation. +- The JAX/Diffrax combinatorial RHS now generates phosphorylation edges per protein from local `n_states`/`n_sites` metadata instead of depending on global dense transition arrays. It still uses JAX-compatible loops and preserves differentiability through kinase scales and site rates. +- Sensitivity trajectory retention is bounded to the configured top-curve count during simulation collection, so all trajectory DataFrames are not accumulated in memory. Output files and ranking/reporting semantics are preserved. +- Multistart optimization already stores only bounded summaries and parameter vectors for starts (`summary`, `parameters`, `best` tables) rather than full trajectories or histories; no trajectory retention is introduced. + +### Interpreting guard errors + +A guard error means the requested combinatorial topology is too large to run safely in this process. Because every additional phosphosite doubles the number of protein state variables, safe settings depend primarily on the maximum number of sites on any single protein and the number of requested output time points. Recommended actions are: + +1. reduce the number of phosphosites included for high-degree proteins; +2. use `MODEL=0` (distributive) or `MODEL=1` (sequential) when combinatorial state occupancy is not required; +3. reduce workload such as the number of starts, sensitivity trajectories, or output times; and +4. rerun after checking the combinatorial diagnostics in the log, which report proteins, `n_sites`, `n_states`, estimated total state dimension, estimated trajectory memory and the dense transition memory that is now avoided. diff --git a/networkmodel/backend.py b/networkmodel/backend.py index bc9ba5b..9ef6626 100644 --- a/networkmodel/backend.py +++ b/networkmodel/backend.py @@ -254,12 +254,6 @@ def make_networkmodel_rhs(sys, slices=None): N = int(idx.N) max_sites = int(np.max(idx.n_sites)) if idx.N else 0 max_states = int(np.max(getattr(idx, "n_states", np.ones(idx.N, dtype=np.int32)))) if idx.N else 1 - trans_from = jnp.asarray(getattr(sys, "trans_from", np.zeros(0, dtype=np.int32)), dtype=jnp.int32) - trans_to = jnp.asarray(getattr(sys, "trans_to", np.zeros(0, dtype=np.int32)), dtype=jnp.int32) - trans_site = jnp.asarray(getattr(sys, "trans_site", np.zeros(0, dtype=np.int32)), dtype=jnp.int32) - trans_off = jnp.asarray(getattr(sys, "trans_off", np.zeros(N, dtype=np.int32)), dtype=jnp.int32) - trans_n = jnp.asarray(getattr(sys, "trans_n", np.zeros(N, dtype=np.int32)), dtype=jnp.int32) - max_trans = int(np.max(getattr(sys, "trans_n", np.zeros(N, dtype=np.int32)))) if N else 0 def _params(args): if isinstance(args, dict): @@ -348,21 +342,20 @@ def _stepwise(row): dp_rate = dp_rate + jnp.where(valid_bit, par["Dp_i"][flat_j] + par["D_i"][i], 0.0) dy = dy.at[pos_m].add(-dp_rate * Pm) - # Explicit phosphorylation transitions from the precomputed sparse - # hypercube graph. Use S_all at the flattened site index so rates - # remain differentiable through kinase scales and site parameters. - for k in range(max_trans): - tr_idx = jnp.minimum(trans_off[i] + k, trans_from.shape[0] - 1) - valid_tr = k < trans_n[i] - frm = trans_from[tr_idx] - to = trans_to[tr_idx] - j = trans_site[tr_idx] - pos_frm = jnp.minimum(p0 + frm, y.shape[0] - 1) - pos_to = jnp.minimum(p0 + to, y.shape[0] - 1) - flat_j = jnp.minimum(s_off + j, S_all.shape[0] - 1) - flux = jnp.where(valid_tr, S_all[flat_j] * y[pos_frm], 0.0) - dy = dy.at[pos_frm].add(-flux) - dy = dy.at[pos_to].add(flux) + # Explicit phosphorylation transitions, generated per protein + # in the same mask/site order as the historical dense arrays. + for m in range(max_states): + valid_m = m < nst + for j in range(max_sites): + bit_unset = ((m >> j) & 1) == 0 + valid_tr = valid_m & (j < ns) & bit_unset + to = m | (1 << j) + pos_frm = jnp.minimum(p0 + m, y.shape[0] - 1) + pos_to = jnp.minimum(p0 + to, y.shape[0] - 1) + flat_j = jnp.minimum(s_off + j, S_all.shape[0] - 1) + flux = jnp.where(valid_tr, S_all[flat_j] * y[pos_frm], 0.0) + dy = dy.at[pos_frm].add(-flux) + dy = dy.at[pos_to].add(flux) else: P = y[off + 1] ar = jnp.arange(max_sites, dtype=jnp.int32) diff --git a/networkmodel/export.py b/networkmodel/export.py index 8c24474..d5b9eff 100644 --- a/networkmodel/export.py +++ b/networkmodel/export.py @@ -1052,9 +1052,11 @@ def export_S_rates(sys, idx, output_dir, filename="S_rates_picked.csv", long=Tru # ---- compute S matrix: shape (total_sites, n_bins) ---- if MODEL == 2: - # Ensure cache matches current optimized c_k - build_S_cache_into(sys.S_cache, sys.W_indptr, sys.W_indices, sys.W_data, sys.kin_Kmat, sys.c_k) - S_mat = np.asarray(sys.S_cache, dtype=np.float64) + # The combinatorial RHS may use a 1-D current-rate work buffer for + # memory safety, but S_rates_picked.csv preserves the historical dense + # site-by-time export shape. Build that dense matrix only for export. + S_mat = np.empty((int(sys.n_W_rows), int(sys.kin_Kmat.shape[1])), dtype=np.float64) + build_S_cache_into(S_mat, sys.W_indptr, sys.W_indices, sys.W_data, sys.kin_Kmat, sys.c_k) times = np.asarray(sys.kin_grid, dtype=float) else: # Dense kinase signal scaled by c_k diff --git a/networkmodel/models.py b/networkmodel/models.py index 8284048..97eaf42 100644 --- a/networkmodel/models.py +++ b/networkmodel/models.py @@ -316,10 +316,9 @@ def _bit_index_from_lsb(lsb): def combinatorial_rhs( y, dy, A_i, B_i, C_i, D_i, Dp_i, E_i, tf_scale, - TF_inputs, S_cache, jb, + TF_inputs, S_rates, offset_y, offset_s, - n_sites, n_states, - trans_from, trans_to, trans_site, trans_off, trans_n + n_sites, n_states ): """Evaluate the combinatorial topology right-hand side @@ -334,21 +333,13 @@ def combinatorial_rhs( E_i: Input value used by this routine. tf_scale: Input value used by this routine. TF_inputs: Input value used by this routine. - S_cache: Input value used by this routine. - jb: Input value used by this routine. + S_rates: Current per-site kinase signal vector. offset_y: Input value used by this routine. offset_s: Input value used by this routine. n_sites: Input value used by this routine. n_states: Input value used by this routine. - trans_from: Input value used by this routine. - trans_to: Input value used by this routine. - trans_site: Input value used by this routine. - trans_off: Input value used by this routine. - trans_n: Input value used by this routine. """ N = A_i.shape[0] - jb_loc = jb # local binding helps Numba - for i in range(N): y_start = offset_y[i] s_start = offset_s[i] @@ -421,59 +412,70 @@ def combinatorial_rhs( dy[base + m] -= dp_rate * Pm # --- Phosphorylation Loop (Forward Transitions) --- - # Uses pre-calculated sparse graph structure - off = trans_off[i] - ntr = trans_n[i] - for k in range(ntr): - frm = trans_from[off + k] - to = trans_to[off + k] - j = trans_site[off + k] + # Enumerate the same hypercube edges in the same order as the former + # dense transition arrays, but do not materialize O(2^n_sites*n_sites) + # arrays. + for m in range(nstates): + for j in range(ns): + bit = 1 << j + if (m & bit) == 0: + to = m | bit + rate = S_rates[s_start + j] + flux = rate * y[base + m] - # Rate depends on time bucket 'jb_loc' - rate = S_cache[s_start + j, jb_loc] - flux = rate * y[base + frm] + dy[base + m] -= flux + dy[base + to] += flux - dy[base + frm] -= flux - dy[base + to] += flux +def iter_random_transitions_for_sites(n_sites): + """Yield combinatorial forward transitions for one protein lazily. -def build_random_transitions(idx): - """Build transition arrays for combinatorial topology - - Args: - idx: Input value used by this routine. - - Returns: - Computed result from this routine. + The order is identical to the historical dense implementation: state mask + first, then site index, yielding only unset-bit phosphorylation edges. + """ + ns = int(n_sites) + if ns <= 0: + return + nstates = 1 << ns + for m in range(nstates): + for j in range(ns): + if (m & (1 << j)) == 0: + yield m, m | (1 << j), j + + +def count_random_transitions_for_sites(n_sites): + """Return the number of combinatorial forward transitions for n sites.""" + ns = int(n_sites) + return 0 if ns <= 0 else ns * (1 << (ns - 1)) + + +def build_random_transitions(idx, *, dense_threshold_sites=4): + """Build small dense transition arrays for compatibility. + + Large proteins are represented by metadata only; callers that need all + transitions should use :func:`iter_random_transitions_for_sites` instead. """ trans_from = [] trans_to = [] trans_site = [] trans_off = np.zeros(idx.N, dtype=np.int32) trans_n = np.zeros(idx.N, dtype=np.int32) + dense_available = np.zeros(idx.N, dtype=np.bool_) cur = 0 for i in range(idx.N): ns = int(idx.n_sites[i]) trans_off[i] = cur - - if ns == 0: - trans_n[i] = 0 - continue - - nstates = 1 << ns - for m in range(nstates): - for j in range(ns): - # If bit j is NOT set in m, we can transition to m | (1< COMBINATORIAL_MAX_STATES_PER_PROTEIN: + raise MemoryError( + f"Unsafe combinatorial MODEL=2 state space for protein {protein!r}: " + f"n_sites={ns} requires n_states=2^n_sites={nstates:,}. " + "MODEL=2 scales exponentially as 2^n_sites; reduce the number of sites, " + "use sequential/distributive models, or lower the workload." + ) + n_states_list.append(nstates) + self.n_states = np.asarray(n_states_list, dtype=np.int64) # Offsets map standard indices to the flattened y-vector self.offset_y = np.zeros(self.N, dtype=np.int32) @@ -126,6 +142,22 @@ def __init__(self, self.state_dim = curr_y self.total_sites = int(curr_s) + if MODEL == 2: + if self.state_dim > COMBINATORIAL_MAX_TOTAL_STATE_DIM: + raise MemoryError( + f"Unsafe combinatorial MODEL=2 total state dimension: {self.state_dim:,} state variables. " + "MODEL=2 scales as 2^n_sites per protein; reduce sites, use sequential/distributive " + "models, or lower workload." + ) + transition_count = int(sum((int(ns) * (1 << (int(ns) - 1))) if int(ns) > 0 else 0 for ns in self.n_sites)) + dense_transition_mb = transition_count * 3 * np.dtype(np.int32).itemsize / (1024 ** 2) + traj_mb = self.state_dim * len(TIME_POINTS_PROTEIN) * np.dtype(np.float64).itemsize / (1024 ** 2) + logger.warning( + "[Model] Combinatorial MODEL=2 diagnostics: proteins=%d n_sites=%s n_states=%s " + "total_state_dim=%d estimated_trajectory=%.2f MiB dense_transition_arrays=%.2f MiB. " + "MODEL=2 scales exponentially as 2^n_sites.", + self.N, self.n_sites.tolist(), self.n_states.tolist(), self.state_dim, traj_mb, dense_transition_mb + ) self.kinase_indices_in_P = [self.p2i[k] for k in self.kinases if k in self.p2i] self.p2k = {k: i for i, k in enumerate(self.kinases)} @@ -269,18 +301,22 @@ def __init__(self, idx, W_global, tf_mat, kin_input, defaults, tf_deg): # MODEL == 2 specific setup # ------------------------------------------------------------ if MODEL == 2: - # reusable work buffers (NO allocs in RHS) + # reusable work buffers (NO allocs in RHS). S_cache is a one-column + # compatibility buffer populated from the current kinase signal, not + # a dense site-by-time cache. self.P_vec_work = np.zeros(self.n_TF_rows, dtype=np.float64) self.TF_in_work = np.zeros(self.n_TF_rows, dtype=np.float64) - self.S_cache = np.zeros((self.n_W_rows, self.kin_Kmat.shape[1]), dtype=np.float64) + self.S_cache = np.zeros(self.n_W_rows, dtype=np.float64) - # precomputed transition lists for the combinatorial hypercube graph + # Dense transitions are retained only for small compatibility tests; + # RHS evaluation streams transitions from n_sites/n_states metadata. ( self.trans_from, self.trans_to, self.trans_site, self.trans_off, self.trans_n, + self.trans_dense_available, ) = build_random_transitions(idx) def update(self, c_k, A_i, B_i, C_i, D_i, Dp_i, E_i, tf_scale): @@ -418,22 +454,16 @@ def rhs(self, t, y): elif MODEL == 2: if self.S_cache is None: - raise ValueError("MODEL==2: System.S_cache is None. simulate_diffrax must set it.") - - jb = int(np.searchsorted(self.kin_grid, t, side="right") - 1) - if jb < 0: - jb = 0 - elif jb >= self.kin_grid.size: - jb = self.kin_grid.size - 1 + raise ValueError("MODEL==2: System.S_cache is None.") + self.S_cache[:] = S_all combinatorial_rhs( y, dy, self.A_i, self.B_i, self.C_i, self.D_i, self.Dp_i, self.E_i, self.tf_scale, TF_inputs, - self.S_cache, jb, + self.S_cache, self.idx.offset_y, self.idx.offset_s, - self.idx.n_sites, self.idx.n_states, - self.trans_from, self.trans_to, self.trans_site, self.trans_off, self.trans_n + self.idx.n_sites, self.idx.n_states ) return dy diff --git a/networkmodel/sensitivity.py b/networkmodel/sensitivity.py index 85c4c55..baa2ace 100644 --- a/networkmodel/sensitivity.py +++ b/networkmodel/sensitivity.py @@ -447,6 +447,13 @@ def run_sensitivity_analysis(sys, idx, fitted_params, output_dir, metric="l2_nor "phos_df": dfph, } ) + # Keep only the bounded set needed for identical top-curve + # reporting instead of retaining every trajectory DataFrame. + if len(trajectory_storage) > int(SENSITIVITY_TOP_CURVES): + trajectory_storage = _select_top_trajectories( + trajectory_storage, + max_items=SENSITIVITY_TOP_CURVES, + ) except Exception as exc: failed_rows.append( diff --git a/networkmodel/simulate.py b/networkmodel/simulate.py index 8cbdf61..5dedcc9 100644 --- a/networkmodel/simulate.py +++ b/networkmodel/simulate.py @@ -11,6 +11,22 @@ from networkmodel.backend import DiffraxSolverConfig, make_networkmodel_rhs, solve_diffrax +def combinatorial_site_signals_streaming(states, n_sites): + """Return per-site combinatorial signals without building an ns x n_sites bit matrix.""" + state_view = np.asarray(states, dtype=np.float64) + if state_view.ndim != 2: + raise ValueError("states must be a two-dimensional time-by-state array") + ns = state_view.shape[1] + n_sites = int(n_sites) + out = np.empty((state_view.shape[0], n_sites), dtype=np.float64) + masks = np.arange(ns, dtype=np.uint64) + for site in range(n_sites): + weights = ((masks >> np.uint64(site)) & np.uint64(1)).astype(np.float64, copy=False) + out[:, site] = state_view @ weights + del weights + return out + + def simulate_diffrax(sys, t_eval, rtol=None, atol=None, max_steps=None, solver_name="Kvaerno4"): """Simulate a System over requested time points with Diffrax @@ -103,20 +119,20 @@ def _bidx(t0: float) -> int: fc_p = np.maximum(tot, 1e-12) / np.maximum(tot[prot_b], 1e-12) rows_p.append(pd.DataFrame({"protein": gene, "time": times, "pred_fc": fc_p})) - # Phospho Sites: Bitwise aggregation - # We map states to sites using a matrix multiplication (State x Bitmask) + # Phospho Sites: Bitwise aggregation, streamed one site at a time + # to avoid an ns x n_sites dense bit matrix. if n_sites > 0: - m = np.arange(ns, dtype=np.uint32)[:, None] - j = np.arange(n_sites, dtype=np.uint32)[None, :] - bits = ((m >> j) & 1).astype(np.float64) # (ns, n_sites) - pho_sites = states @ bits # (T, n_sites) - + masks = np.arange(ns, dtype=np.uint64) for s_idx, psite in enumerate(idx.sites[i]): - sig = pho_sites[:, s_idx] + weights = ((masks >> np.uint64(s_idx)) & np.uint64(1)).astype(np.float64, copy=False) + sig = states @ weights fc = np.maximum(sig, 1e-12) / np.maximum(sig[pho_b], 1e-12) rows_pho.append(pd.DataFrame({ "protein": gene, "psite": psite, "time": times, "pred_fc": fc })) + del weights, sig + del masks + del states else: # --- Standard Model Extraction (Distributive/Sequential) --- diff --git a/tests/test_combinatorial_memory_safe.py b/tests/test_combinatorial_memory_safe.py new file mode 100644 index 0000000..8ae881e --- /dev/null +++ b/tests/test_combinatorial_memory_safe.py @@ -0,0 +1,76 @@ +import numpy as np +import pandas as pd +import pytest + +from networkmodel.models import iter_random_transitions_for_sites, build_random_transitions +from networkmodel.simulate import combinatorial_site_signals_streaming + + +def _dense_transitions(n_sites): + out = [] + for m in range(1 << n_sites): + for j in range(n_sites): + if (m & (1 << j)) == 0: + out.append((m, m | (1 << j), j)) + return out + + +@pytest.mark.parametrize("n_sites", [2, 3, 4]) +def test_streaming_bit_extraction_matches_dense(n_sites): + rng = np.random.default_rng(123 + n_sites) + states = rng.normal(size=(5, 1 << n_sites)) + m = np.arange(1 << n_sites, dtype=np.uint32)[:, None] + j = np.arange(n_sites, dtype=np.uint32)[None, :] + bits = ((m >> j) & 1).astype(np.float64) + expected = states @ bits + actual = combinatorial_site_signals_streaming(states, n_sites) + np.testing.assert_allclose(actual, expected, rtol=0.0, atol=1e-12) + + +@pytest.mark.parametrize("n_sites", [0, 1, 2, 3, 4]) +def test_transition_iterator_matches_dense_order(n_sites): + assert list(iter_random_transitions_for_sites(n_sites)) == _dense_transitions(n_sites) + + +def test_build_random_transitions_keeps_small_dense_compatibility(): + class Idx: + N = 1 + n_sites = np.array([2], dtype=np.int32) + + frm, to, site, off, ntr, dense = build_random_transitions(Idx()) + expected = _dense_transitions(2) + assert off.tolist() == [0] + assert ntr.tolist() == [len(expected)] + assert dense.tolist() == [True] + assert list(zip(frm.tolist(), to.tolist(), site.tolist())) == expected + + +def test_build_random_transitions_avoids_large_dense_storage(): + class Idx: + N = 1 + n_sites = np.array([5], dtype=np.int32) + + frm, to, site, off, ntr, dense = build_random_transitions(Idx(), dense_threshold_sites=4) + assert frm.size == to.size == site.size == 0 + assert off.tolist() == [0] + assert ntr.tolist() == [5 * (1 << 4)] + assert dense.tolist() == [False] + + +def test_tiny_combinatorial_smoke_signals(): + states = np.array([[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 5.0, 7.0]]) + signals = combinatorial_site_signals_streaming(states, 2) + expected = np.array([[6.0, 7.0], [10.0, 12.0]]) + np.testing.assert_allclose(signals, expected, rtol=0.0, atol=1e-12) + + +def test_memory_guard_message(monkeypatch): + import networkmodel.network as network + + monkeypatch.setattr(network, "MODEL", 2) + monkeypatch.setattr(network, "COMBINATORIAL_MAX_STATES_PER_PROTEIN", 4) + interactions = pd.DataFrame( + {"protein": ["P", "P", "P"], "psite": ["S1", "S2", "S3"], "kinase": ["K", "K", "K"]} + ) + with pytest.raises(MemoryError, match="MODEL=2 scales"): + network.Index(interactions)