Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions docs/Combinatorial_Model_Memory_Issue.md
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,24 @@ Implementing the above fixes should dramatically reduce memory usage:
[\[4\]](https://raw.githubusercontent.com/bibymaths/phoskintime/global/networkmodel/backend.py#:~:text=) raw.githubusercontent.com

<https://raw.githubusercontent.com/bibymaths/phoskintime/global/networkmodel/backend.py>

## 11. Implemented Memory-Safe Behavior

The combinatorial implementation now keeps the same biological equations, objective values, ranking semantics, CLI options, input formats and output tables, while reducing unnecessary materialization:

- `MODEL=2` still has an inherent `2^n_sites` state space. `Index` now validates each protein after computing `n_states = 1 << n_sites` and raises an informative `MemoryError` before simulation when the per-protein or total state dimension is unsafe. The error explains the exponential scaling and recommends reducing sites, choosing sequential/distributive models, or lowering workload.
- Combinatorial transition order is unchanged: transitions are enumerated by state mask first and site index second, yielding only unset-bit phosphorylation edges. Large proteins no longer require dense `trans_from`, `trans_to` and `trans_site` arrays for RHS evaluation; dense arrays are retained only for small compatibility cases.
- Combinatorial phosphosite signal extraction no longer builds the dense `ns x n_sites` bit matrix. Each site is aggregated with the equivalent bit-mask vector and temporaries are released immediately after use. This changes only memory materialization, not returned DataFrame columns or values.
- `S_cache` for combinatorial mode is now a reusable current-site-rate buffer rather than an unconditional site-by-time dense matrix. The RHS receives the same current rates used by the previous time-bucketed calculation.
- The JAX/Diffrax combinatorial RHS now generates phosphorylation edges per protein from local `n_states`/`n_sites` metadata instead of depending on global dense transition arrays. It still uses JAX-compatible loops and preserves differentiability through kinase scales and site rates.
- Sensitivity trajectory retention is bounded to the configured top-curve count during simulation collection, so all trajectory DataFrames are not accumulated in memory. Output files and ranking/reporting semantics are preserved.
- Multistart optimization already stores only bounded summaries and parameter vectors for starts (`summary`, `parameters`, `best` tables) rather than full trajectories or histories; no trajectory retention is introduced.

### Interpreting guard errors

A guard error means the requested combinatorial topology is too large to run safely in this process. Because every additional phosphosite doubles the number of protein state variables, safe settings depend primarily on the maximum number of sites on any single protein and the number of requested output time points. Recommended actions are:

1. reduce the number of phosphosites included for high-degree proteins;
2. use `MODEL=0` (distributive) or `MODEL=1` (sequential) when combinatorial state occupancy is not required;
3. reduce workload such as the number of starts, sensitivity trajectories, or output times; and
4. rerun after checking the combinatorial diagnostics in the log, which report proteins, `n_sites`, `n_states`, estimated total state dimension, estimated trajectory memory and the dense transition memory that is now avoided.
35 changes: 14 additions & 21 deletions networkmodel/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,6 @@ def make_networkmodel_rhs(sys, slices=None):
N = int(idx.N)
max_sites = int(np.max(idx.n_sites)) if idx.N else 0
max_states = int(np.max(getattr(idx, "n_states", np.ones(idx.N, dtype=np.int32)))) if idx.N else 1
trans_from = jnp.asarray(getattr(sys, "trans_from", np.zeros(0, dtype=np.int32)), dtype=jnp.int32)
trans_to = jnp.asarray(getattr(sys, "trans_to", np.zeros(0, dtype=np.int32)), dtype=jnp.int32)
trans_site = jnp.asarray(getattr(sys, "trans_site", np.zeros(0, dtype=np.int32)), dtype=jnp.int32)
trans_off = jnp.asarray(getattr(sys, "trans_off", np.zeros(N, dtype=np.int32)), dtype=jnp.int32)
trans_n = jnp.asarray(getattr(sys, "trans_n", np.zeros(N, dtype=np.int32)), dtype=jnp.int32)
max_trans = int(np.max(getattr(sys, "trans_n", np.zeros(N, dtype=np.int32)))) if N else 0

def _params(args):
if isinstance(args, dict):
Expand Down Expand Up @@ -348,21 +342,20 @@ def _stepwise(row):
dp_rate = dp_rate + jnp.where(valid_bit, par["Dp_i"][flat_j] + par["D_i"][i], 0.0)
dy = dy.at[pos_m].add(-dp_rate * Pm)

# Explicit phosphorylation transitions from the precomputed sparse
# hypercube graph. Use S_all at the flattened site index so rates
# remain differentiable through kinase scales and site parameters.
for k in range(max_trans):
tr_idx = jnp.minimum(trans_off[i] + k, trans_from.shape[0] - 1)
valid_tr = k < trans_n[i]
frm = trans_from[tr_idx]
to = trans_to[tr_idx]
j = trans_site[tr_idx]
pos_frm = jnp.minimum(p0 + frm, y.shape[0] - 1)
pos_to = jnp.minimum(p0 + to, y.shape[0] - 1)
flat_j = jnp.minimum(s_off + j, S_all.shape[0] - 1)
flux = jnp.where(valid_tr, S_all[flat_j] * y[pos_frm], 0.0)
dy = dy.at[pos_frm].add(-flux)
dy = dy.at[pos_to].add(flux)
# Explicit phosphorylation transitions, generated per protein
# in the same mask/site order as the historical dense arrays.
for m in range(max_states):
valid_m = m < nst
for j in range(max_sites):
bit_unset = ((m >> j) & 1) == 0
valid_tr = valid_m & (j < ns) & bit_unset
to = m | (1 << j)
pos_frm = jnp.minimum(p0 + m, y.shape[0] - 1)
pos_to = jnp.minimum(p0 + to, y.shape[0] - 1)
flat_j = jnp.minimum(s_off + j, S_all.shape[0] - 1)
flux = jnp.where(valid_tr, S_all[flat_j] * y[pos_frm], 0.0)
dy = dy.at[pos_frm].add(-flux)
dy = dy.at[pos_to].add(flux)
else:
P = y[off + 1]
ar = jnp.arange(max_sites, dtype=jnp.int32)
Expand Down
8 changes: 5 additions & 3 deletions networkmodel/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,9 +1052,11 @@ def export_S_rates(sys, idx, output_dir, filename="S_rates_picked.csv", long=Tru

# ---- compute S matrix: shape (total_sites, n_bins) ----
if MODEL == 2:
# Ensure cache matches current optimized c_k
build_S_cache_into(sys.S_cache, sys.W_indptr, sys.W_indices, sys.W_data, sys.kin_Kmat, sys.c_k)
S_mat = np.asarray(sys.S_cache, dtype=np.float64)
# The combinatorial RHS may use a 1-D current-rate work buffer for
# memory safety, but S_rates_picked.csv preserves the historical dense
# site-by-time export shape. Build that dense matrix only for export.
S_mat = np.empty((int(sys.n_W_rows), int(sys.kin_Kmat.shape[1])), dtype=np.float64)
build_S_cache_into(S_mat, sys.W_indptr, sys.W_indices, sys.W_data, sys.kin_Kmat, sys.c_k)
times = np.asarray(sys.kin_grid, dtype=float)
else:
# Dense kinase signal scaled by c_k
Expand Down
103 changes: 53 additions & 50 deletions networkmodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,9 @@ def _bit_index_from_lsb(lsb):
def combinatorial_rhs(
y, dy,
A_i, B_i, C_i, D_i, Dp_i, E_i, tf_scale,
TF_inputs, S_cache, jb,
TF_inputs, S_rates,
offset_y, offset_s,
n_sites, n_states,
trans_from, trans_to, trans_site, trans_off, trans_n
n_sites, n_states
):
"""Evaluate the combinatorial topology right-hand side

Expand All @@ -334,21 +333,13 @@ def combinatorial_rhs(
E_i: Input value used by this routine.
tf_scale: Input value used by this routine.
TF_inputs: Input value used by this routine.
S_cache: Input value used by this routine.
jb: Input value used by this routine.
S_rates: Current per-site kinase signal vector.
offset_y: Input value used by this routine.
offset_s: Input value used by this routine.
n_sites: Input value used by this routine.
n_states: Input value used by this routine.
trans_from: Input value used by this routine.
trans_to: Input value used by this routine.
trans_site: Input value used by this routine.
trans_off: Input value used by this routine.
trans_n: Input value used by this routine.
"""
N = A_i.shape[0]
jb_loc = jb # local binding helps Numba

for i in range(N):
y_start = offset_y[i]
s_start = offset_s[i]
Expand Down Expand Up @@ -421,64 +412,76 @@ def combinatorial_rhs(
dy[base + m] -= dp_rate * Pm

# --- Phosphorylation Loop (Forward Transitions) ---
# Uses pre-calculated sparse graph structure
off = trans_off[i]
ntr = trans_n[i]
for k in range(ntr):
frm = trans_from[off + k]
to = trans_to[off + k]
j = trans_site[off + k]
# Enumerate the same hypercube edges in the same order as the former
# dense transition arrays, but do not materialize O(2^n_sites*n_sites)
# arrays.
for m in range(nstates):
for j in range(ns):
bit = 1 << j
if (m & bit) == 0:
to = m | bit
rate = S_rates[s_start + j]
flux = rate * y[base + m]

# Rate depends on time bucket 'jb_loc'
rate = S_cache[s_start + j, jb_loc]
flux = rate * y[base + frm]
dy[base + m] -= flux
dy[base + to] += flux

dy[base + frm] -= flux
dy[base + to] += flux

def iter_random_transitions_for_sites(n_sites):
"""Yield combinatorial forward transitions for one protein lazily.

def build_random_transitions(idx):
"""Build transition arrays for combinatorial topology

Args:
idx: Input value used by this routine.

Returns:
Computed result from this routine.
The order is identical to the historical dense implementation: state mask
first, then site index, yielding only unset-bit phosphorylation edges.
"""
ns = int(n_sites)
if ns <= 0:
return
nstates = 1 << ns
for m in range(nstates):
for j in range(ns):
if (m & (1 << j)) == 0:
yield m, m | (1 << j), j


def count_random_transitions_for_sites(n_sites):
"""Return the number of combinatorial forward transitions for n sites."""
ns = int(n_sites)
return 0 if ns <= 0 else ns * (1 << (ns - 1))


def build_random_transitions(idx, *, dense_threshold_sites=4):
"""Build small dense transition arrays for compatibility.

Large proteins are represented by metadata only; callers that need all
transitions should use :func:`iter_random_transitions_for_sites` instead.
"""
trans_from = []
trans_to = []
trans_site = []
trans_off = np.zeros(idx.N, dtype=np.int32)
trans_n = np.zeros(idx.N, dtype=np.int32)
dense_available = np.zeros(idx.N, dtype=np.bool_)

cur = 0
for i in range(idx.N):
ns = int(idx.n_sites[i])
trans_off[i] = cur

if ns == 0:
trans_n[i] = 0
continue

nstates = 1 << ns
for m in range(nstates):
for j in range(ns):
# If bit j is NOT set in m, we can transition to m | (1<<j)
if (m & (1 << j)) == 0:
mp = m | (1 << j)
trans_from.append(m)
trans_to.append(mp)
trans_site.append(j)

n_i = len(trans_from) - cur
trans_n[i] = n_i
cur += n_i
trans_n[i] = count_random_transitions_for_sites(ns)
if ns <= int(dense_threshold_sites):
dense_available[i] = True
for frm, to, site in iter_random_transitions_for_sites(ns):
trans_from.append(frm)
trans_to.append(to)
trans_site.append(site)
cur = len(trans_from)
else:
dense_available[i] = False

return (
np.asarray(trans_from, dtype=np.int32),
np.asarray(trans_to, dtype=np.int32),
np.asarray(trans_site, dtype=np.int32),
trans_off,
trans_n,
dense_available,
)
58 changes: 44 additions & 14 deletions networkmodel/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@

logger = setup_logger(log_dir=RESULTS_DIR)

COMBINATORIAL_MAX_STATES_PER_PROTEIN = 1 << 16
COMBINATORIAL_MAX_TOTAL_STATE_DIM = 5_000_000



class Index:
"""Map proteins, sites, kinases, and state-vector offsets"""
Expand Down Expand Up @@ -104,7 +108,19 @@ def __init__(self,
self.n_sites = np.array([len(s) for s in self.sites], dtype=np.int32)

if MODEL == 2:
self.n_states = np.array([1 << int(ns) for ns in self.n_sites], dtype=np.int32)
n_states_list = []
for protein, ns_raw in zip(self.proteins, self.n_sites):
ns = int(ns_raw)
nstates = 1 << ns
if nstates > COMBINATORIAL_MAX_STATES_PER_PROTEIN:
raise MemoryError(
f"Unsafe combinatorial MODEL=2 state space for protein {protein!r}: "
f"n_sites={ns} requires n_states=2^n_sites={nstates:,}. "
"MODEL=2 scales exponentially as 2^n_sites; reduce the number of sites, "
"use sequential/distributive models, or lower the workload."
)
n_states_list.append(nstates)
self.n_states = np.asarray(n_states_list, dtype=np.int64)

# Offsets map standard indices to the flattened y-vector
self.offset_y = np.zeros(self.N, dtype=np.int32)
Expand All @@ -126,6 +142,22 @@ def __init__(self,

self.state_dim = curr_y
self.total_sites = int(curr_s)
if MODEL == 2:
if self.state_dim > COMBINATORIAL_MAX_TOTAL_STATE_DIM:
raise MemoryError(
f"Unsafe combinatorial MODEL=2 total state dimension: {self.state_dim:,} state variables. "
"MODEL=2 scales as 2^n_sites per protein; reduce sites, use sequential/distributive "
"models, or lower workload."
)
transition_count = int(sum((int(ns) * (1 << (int(ns) - 1))) if int(ns) > 0 else 0 for ns in self.n_sites))
dense_transition_mb = transition_count * 3 * np.dtype(np.int32).itemsize / (1024 ** 2)
traj_mb = self.state_dim * len(TIME_POINTS_PROTEIN) * np.dtype(np.float64).itemsize / (1024 ** 2)
logger.warning(
"[Model] Combinatorial MODEL=2 diagnostics: proteins=%d n_sites=%s n_states=%s "
"total_state_dim=%d estimated_trajectory=%.2f MiB dense_transition_arrays=%.2f MiB. "
"MODEL=2 scales exponentially as 2^n_sites.",
self.N, self.n_sites.tolist(), self.n_states.tolist(), self.state_dim, traj_mb, dense_transition_mb
)
self.kinase_indices_in_P = [self.p2i[k] for k in self.kinases if k in self.p2i]
self.p2k = {k: i for i, k in enumerate(self.kinases)}

Expand Down Expand Up @@ -269,18 +301,22 @@ def __init__(self, idx, W_global, tf_mat, kin_input, defaults, tf_deg):
# MODEL == 2 specific setup
# ------------------------------------------------------------
if MODEL == 2:
# reusable work buffers (NO allocs in RHS)
# reusable work buffers (NO allocs in RHS). S_cache is a one-column
# compatibility buffer populated from the current kinase signal, not
# a dense site-by-time cache.
self.P_vec_work = np.zeros(self.n_TF_rows, dtype=np.float64)
self.TF_in_work = np.zeros(self.n_TF_rows, dtype=np.float64)
self.S_cache = np.zeros((self.n_W_rows, self.kin_Kmat.shape[1]), dtype=np.float64)
self.S_cache = np.zeros(self.n_W_rows, dtype=np.float64)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep MODEL=2 S-rate exports two-dimensional

In the normal runner path I checked, runner.py calls export_S_rates after optimization for every model; for MODEL==2 that exporter still calls build_S_cache_into(sys.S_cache, ...) and treats sys.S_cache as a (total_sites, n_timebins) matrix before writing S_rates_picked.csv. Changing the cache here to a 1-D current-rate buffer makes combinatorial runs fail during result export rather than producing the expected output, so the exporter needs its own dense matrix or a streaming export path.

Useful? React with 👍 / 👎.


# precomputed transition lists for the combinatorial hypercube graph
# Dense transitions are retained only for small compatibility tests;
# RHS evaluation streams transitions from n_sites/n_states metadata.
(
self.trans_from,
self.trans_to,
self.trans_site,
self.trans_off,
self.trans_n,
self.trans_dense_available,
) = build_random_transitions(idx)

def update(self, c_k, A_i, B_i, C_i, D_i, Dp_i, E_i, tf_scale):
Expand Down Expand Up @@ -418,22 +454,16 @@ def rhs(self, t, y):
elif MODEL == 2:

if self.S_cache is None:
raise ValueError("MODEL==2: System.S_cache is None. simulate_diffrax must set it.")

jb = int(np.searchsorted(self.kin_grid, t, side="right") - 1)
if jb < 0:
jb = 0
elif jb >= self.kin_grid.size:
jb = self.kin_grid.size - 1
raise ValueError("MODEL==2: System.S_cache is None.")
self.S_cache[:] = S_all

combinatorial_rhs(
y, dy,
self.A_i, self.B_i, self.C_i, self.D_i, self.Dp_i, self.E_i, self.tf_scale,
TF_inputs,
self.S_cache, jb,
self.S_cache,
self.idx.offset_y, self.idx.offset_s,
self.idx.n_sites, self.idx.n_states,
self.trans_from, self.trans_to, self.trans_site, self.trans_off, self.trans_n
self.idx.n_sites, self.idx.n_states
)
return dy

Expand Down
7 changes: 7 additions & 0 deletions networkmodel/sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,13 @@ def run_sensitivity_analysis(sys, idx, fitted_params, output_dir, metric="l2_nor
"phos_df": dfph,
}
)
# Keep only the bounded set needed for identical top-curve
# reporting instead of retaining every trajectory DataFrame.
if len(trajectory_storage) > int(SENSITIVITY_TOP_CURVES):
trajectory_storage = _select_top_trajectories(
trajectory_storage,
max_items=SENSITIVITY_TOP_CURVES,
)

except Exception as exc:
failed_rows.append(
Expand Down
Loading
Loading