Add optional pure-JAX network_preprocessing module (hyperedge discovery, pruning, motifs, sparse theta, identifiability) #75
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: b046cfbe4a
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
        min_score, min_support, prune_self_loops, prune_missing_observations):
    flags = jnp.zeros(score.shape, dtype=jnp.uint32)
    flags = jnp.where(score < min_score, flags | FLAG_LOW_SCORE, flags)
    flags = jnp.where(support_count < min_support, flags | FLAG_LOW_SUPPORT, flags)
Compare labels before pruning self-loops
When network preprocessing is enabled with the default prune_self_loops=True, this compares ids from two independently sorted label spaces (kinase_labels and substrate_labels are built separately in encode_kinase_network). If the kinase set and substrate set differ, valid edges are removed whenever their per-space ranks match; for example kinases {A,B} and substrates {B,C} make A -> B look like 0 == 0, so the hook can silently drop real kinase triplets before Index/build_W_parallel see them. Use a shared protein id or an explicit same-label mask for self-loop detection.
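A minimal sketch of the failure and one possible fix, in plain Python rather than JAX. The helper rank_labels and the edge list are hypothetical and only mirror the {A,B} / {B,C} example above; the real encode_kinase_network may build its labels differently.

```python
def rank_labels(labels):
    """Map each label to its rank within its own sorted label set."""
    return {label: i for i, label in enumerate(sorted(set(labels)))}

kinases = ["A", "B"]
substrates = ["B", "C"]
edges = [("A", "B"), ("B", "B"), ("B", "C")]  # (kinase, substrate) pairs

# Independent label spaces: the same integer can name different proteins.
kin_rank = rank_labels(kinases)      # {"A": 0, "B": 1}
sub_rank = rank_labels(substrates)   # {"B": 0, "C": 1}
naive_self_loop = [kin_rank[k] == sub_rank[s] for k, s in edges]

# Shared namespace: one id per protein, regardless of its role in the edge.
node_id = rank_labels(kinases + substrates)  # {"A": 0, "B": 1, "C": 2}
shared_self_loop = [node_id[k] == node_id[s] for k, s in edges]

print(naive_self_loop)   # [True, False, True]
print(shared_self_loop)  # [False, True, False]
```

With independent ranks, the comparison flags the two real edges and misses the only true self-loop (B -> B); with shared ids only B -> B is flagged.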
    keys = indices[:,0]*jnp.int32(1000003) + indices[:,1]*jnp.int32(1009) + indices[:,2]
    order = jnp.argsort(keys)
    sorted_keys = keys[order]
    new_group = jnp.concatenate([jnp.array([True]), sorted_keys[1:] != sorted_keys[:-1]])
Return empty groups for empty theta inputs
If pruning removes every triplet, identifiability_kernel still prepends True here, producing a one-element gid for zero values. That makes IdentifiabilityDiagnostics.group_id longer than the other arrays and causes the export DataFrame to fail when output_dir is provided; this can happen with valid options such as a high --network-preprocessing-min-score or missing-observation pruning. Handle the zero-key case by returning empty gid/diagnostic arrays.
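One way to guard the zero-key case, sketched with NumPy so it runs without JAX. The function name group_ids is hypothetical, and the sketch assumes group ids are derived from a cumulative sum over the new_group mask, which the quoted snippet does not show.

```python
import numpy as np

def group_ids(keys):
    """Assign a dense group id to each key; equal keys share an id."""
    keys = np.asarray(keys, dtype=np.int64)
    if keys.size == 0:
        # Nothing survived pruning: return an empty id array so every
        # diagnostic column has length zero.
        return np.zeros(0, dtype=np.int64)
    order = np.argsort(keys, kind="stable")
    sorted_keys = keys[order]
    # The first sorted element always opens a group; later ones open a
    # group only when the key changes.
    new_group = np.concatenate([[True], sorted_keys[1:] != sorted_keys[:-1]])
    gid_sorted = np.cumsum(new_group) - 1
    gid = np.empty_like(gid_sorted)
    gid[order] = gid_sorted  # scatter ids back to the input order
    return gid

print(group_ids([]).shape)           # (0,)
print(group_ids([5, 3, 5]).tolist())  # [1, 0, 1]
```

Under jax.jit the array size is static at trace time, so a Python-level size check like this one should also be usable in the JAX kernel, though that is untested here.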
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 242760b6b9
    src=list(map(int, jnp.asarray(triplets.kinase_ids))); dst=list(map(int, jnp.asarray(triplets.substrate_ids))); sc=list(map(float, jnp.asarray(triplets.score)))
    edge={(a,b):s for a,b,s in zip(src,dst,sc) if a!=b}; rows=[]
Build motifs with shared protein ids
When motif detection is enabled and the kinase and substrate label sets differ, this treats kinase_ids and substrate_ids as if they were drawn from one shared node namespace, even though encode_kinase_network builds those label maps independently. For example, with kinases {A, X} and substrates {B, C, D}, edges A -> C, X -> D, and A -> D have integer pairs (0,1), (1,2), (0,2), so the detector reports a feed-forward loop through node 1 even though substrate C is not kinase X. This silently corrupts the exported/default motif diagnostics for networks where the two protein sets are not identical; map labels to a common protein id space before wedge tests and self-edge filtering.
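The example in this comment can be reproduced with a small brute-force detector. This is an illustrative sketch, not the module's motif code: feed_forward_loops and the label lists are hypothetical, and the detector only looks for a -> b, b -> c, a -> c wedges.

```python
def feed_forward_loops(edges):
    """Return sorted (a, b, c) triples with edges a->b, b->c and a->c."""
    edge_set = {(a, b) for a, b in edges if a != b}  # drop self-edges
    found = set()
    for a, b in edge_set:
        for b2, c in edge_set:
            if b2 == b and c != a and (a, c) in edge_set:
                found.add((a, b, c))
    return sorted(found)

kinase_labels = ["A", "X"]          # independently sorted: A=0, X=1
substrate_labels = ["B", "C", "D"]  # independently sorted: B=0, C=1, D=2
labelled_edges = [("A", "C"), ("X", "D"), ("A", "D")]

# Mixing the two rank spaces: substrate C (rank 1) collides with kinase X (rank 1).
mixed = [(kinase_labels.index(k), substrate_labels.index(s)) for k, s in labelled_edges]

# Shared namespace over the union of both label sets.
proteins = sorted(set(kinase_labels) | set(substrate_labels))
node_id = {p: i for i, p in enumerate(proteins)}
shared = [(node_id[k], node_id[s]) for k, s in labelled_edges]

print(feed_forward_loops(mixed))   # [(0, 1, 2)]
print(feed_forward_loops(shared))  # []
```

The spurious loop disappears once both endpoints are drawn from the same id space, because no edge leaves C in the labelled network.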
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 1c948b9841
def score_triplets(edge_weight, support_count, site_observed, kinase_observed, kinase_ids, substrate_ids):
    score = jnp.log1p(edge_weight.astype(jnp.float64)) + 0.5*jnp.log1p(support_count.astype(jnp.float64))
    score = score + 0.25*site_observed.astype(jnp.float64) + 0.25*kinase_observed.astype(jnp.float64)
    score = score - 0.5*(kinase_ids == substrate_ids).astype(jnp.float64)
Use shared labels when penalizing self-loops
Separate from the pruning flag, this scoring path still compares IDs from the independently sorted kinase and substrate label spaces. When those label sets differ, a valid edge such as A -> B can receive the self-loop penalty just because both labels have rank 0, so --network-preprocessing-keep-self-loops, score thresholds, or max_triplets ranking can still drop or deprioritize real triplets; compute the penalty from a shared protein namespace or explicit label equality instead.
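A scalar restatement of the quoted score with the self-loop term computed from label equality. This is a sketch only: score_triplet is a hypothetical name, the weights are copied from the quoted lines, and the vectorised kernel would instead need ids from a shared protein namespace.

```python
import math

def score_triplet(edge_weight, support_count, site_observed, kinase_observed,
                  kinase_label, substrate_label):
    """Scalar version of the quoted score; self-loop penalty taken from labels."""
    score = math.log1p(edge_weight) + 0.5 * math.log1p(support_count)
    score += 0.25 * float(site_observed) + 0.25 * float(kinase_observed)
    is_self_loop = kinase_label == substrate_label  # label equality, not rank equality
    return score - 0.5 * float(is_self_loop)

real_edge = score_triplet(1.0, 2, True, True, "A", "B")
self_loop = score_triplet(1.0, 2, True, True, "A", "A")
print(real_edge - self_loop)  # 0.5
```

Here A -> B keeps its full score even though A and B may both have rank 0 in their separate label spaces.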
    min_triplet_score=float(getattr(args, "network_preprocessing_min_score", 0.0)),
    discovery_threshold=float(getattr(args, "network_preprocessing_discovery_threshold", 0.0)),
    min_support_count=int(getattr(args, "network_preprocessing_min_support", 1)),
    max_triplets=getattr(args, "network_preprocessing_max_triplets", None),
Treat a zero CLI cap as unbounded
The TOML config documents max_triplets = 0 as “no explicit cap” and the TOML parser normalizes that to None, but CLI values are copied through unchanged here. If a user runs with --network-preprocessing-max-triplets 0, prune_triplets treats it as a real cap of zero and removes every otherwise retained kinase triplet, causing an empty pruned network; normalize 0 from args the same way as the config loader.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: cb4f8c2d6c
df_kin, network_preprocessing_result = preprocess_networkmodel_frames(
    df_kin, df_tf, df_prot, df_pho, df_rna,
    config=prep_config, output_dir=args.output_dir, logger=logger,
Reuse the pruned network in standalone analyses
When preprocessing is enabled together with --posterior-sampling or --profile-likelihood, the main objective is built from this pruned df_kin, but the standalone workers rebuild from the original input paths written by write_posterior_payload; build_networkmodel_posterior_context reloads load_data(args) and never reruns this pruning. For any threshold/cap that removes a site/protein, the worker hits the saved-vs-rebuilt theta shape check; when dimensions happen to match, it samples/profiles a different W matrix. Persist the pruned network or pass enough config for the workers to reproduce preprocessing.
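The "persist the pruned network" option could look roughly like the sketch below. Every name here is hypothetical (the file names, the payload key, and both functions); the actual write_posterior_payload and build_networkmodel_posterior_context signatures are not shown in this thread. The kinase/protein/psite columns are taken from the row attributes quoted elsewhere in the review.

```python
import csv
import json
import os
import tempfile

def write_payload_with_pruned_network(rows, output_dir):
    """Save the pruned kinase rows and record their path in a payload file."""
    pruned_path = os.path.join(output_dir, "pruned_kinase_network.csv")  # hypothetical name
    with open(pruned_path, "w", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=["kinase", "protein", "psite"])
        writer.writeheader()
        writer.writerows(rows)
    payload_path = os.path.join(output_dir, "payload.json")  # hypothetical name
    with open(payload_path, "w") as fh:
        json.dump({"pruned_kinase_network": pruned_path}, fh)
    return payload_path

def load_network_for_worker(payload_path):
    """Worker side: read the pruned rows named in the payload, not the raw input."""
    with open(payload_path) as fh:
        payload = json.load(fh)
    with open(payload["pruned_kinase_network"], newline="") as fh:
        return list(csv.DictReader(fh))

with tempfile.TemporaryDirectory() as tmp:
    path = write_payload_with_pruned_network(
        [{"kinase": "A", "protein": "B", "psite": "S10"}], tmp)
    worker_rows = load_network_for_worker(path)
```

Because the worker reads exactly the rows the main run fitted, the saved and rebuilt theta shapes cannot drift apart through a rerun of pruning with different options.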
    min_triplet_score=float(getattr(args, "network_preprocessing_min_score", 0.0)),
    discovery_threshold=float(getattr(args, "network_preprocessing_discovery_threshold", 0.0)),
    min_support_count=int(getattr(args, "network_preprocessing_min_support", 1)),
    max_triplets=(None if getattr(args, "network_preprocessing_max_triplets", None) == 0 else getattr(args, "network_preprocessing_max_triplets", None)),
Reject negative CLI triplet caps
The TOML loader rejects negative max_triplets, but CLI args are copied through unchanged here. With --network-preprocessing-max-triplets -1, prune_triplets enters the budget branch and Python slicing [:-1] silently drops the lowest-ranked retained triplet instead of reporting invalid input, so a typo can alter experiments. Validate negative CLI values the same way as TOML.
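A small normalizer covers both the zero case and the negative case in one place. This is a sketch with a hypothetical name, normalize_max_triplets; the TOML loader's actual error type and message are not shown in this thread, so ValueError is an assumption.

```python
def normalize_max_triplets(value):
    """Return None for 'no cap', a positive int for a real cap, else raise."""
    if value is None:
        return None
    cap = int(value)
    if cap < 0:
        raise ValueError(f"max_triplets must be >= 0, got {cap}")
    return None if cap == 0 else cap

# Why an unchecked cap is dangerous: slicing accepts 0 and -1 without complaint.
ranked = ["t1", "t2", "t3"]  # triplets ordered best to worst
print(ranked[:0])    # [] (a cap of zero removes everything)
print(ranked[:-1])   # ['t1', 't2'] (a cap of -1 drops the last item)
print(normalize_max_triplets(0))  # None
```

Calling the same helper from both the CLI path and the config loader would keep the two entry points from diverging again.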
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 4088384ece
    res=preprocess_network(df_kin, phospho_observations=df_pho, protein_observations=df_prot, rna_observations=df_rna, tf_network=df_tf, config=config, output_dir=output_dir, logger=logger)
    enc=res.encoded
    retained={(enc.kinase_labels[int(k)], enc.site_labels[int(s)].split(":",1)[0], enc.site_labels[int(s)].split(":",1)[1]) for k,s in zip(res.pruned.kinase_ids, res.pruned.site_ids)}
    out=df_kin.copy(); mask=[(str(r.kinase).strip().upper(), str(r.protein).strip().upper(), str(r.psite).strip()) in retained for r in out.itertuples()]
Collapse duplicate triplets before returning df_kin
When df_kin contains duplicate rows for the same kinase/protein/site (which encode_kinase_network explicitly groups and represents via support_count), this mask keeps every original row for a retained key. Downstream build_W_parallel consumes every row and duplicate sparse coordinates are summed, so the fitted model can use a summed alpha even though preprocessing scored/pruned a single grouped triplet using max alpha; repeated rows can also make the worker CSV disagree with the retained triplet table. Return a grouped/pruned frame with one row per retained triplet and the selected alpha instead of filtering the ungrouped input.
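A plain-Python sketch of the grouping this comment asks for, using dict rows instead of a pandas frame so it stays dependency-free. The function name collapse_triplets and the column name alpha are assumptions; the key normalization copies the strip/upper calls from the quoted mask.

```python
def collapse_triplets(rows, retained):
    """Group rows by (kinase, protein, psite); keep max alpha and a support count."""
    grouped = {}
    for row in rows:
        key = (row["kinase"].strip().upper(),
               row["protein"].strip().upper(),
               row["psite"].strip())
        if key not in retained:
            continue  # triplet was pruned
        alpha, support = grouped.get(key, (float("-inf"), 0))
        grouped[key] = (max(alpha, float(row["alpha"])), support + 1)
    return [
        {"kinase": k, "protein": p, "psite": s, "alpha": alpha, "support_count": n}
        for (k, p, s), (alpha, n) in sorted(grouped.items())
    ]

rows = [
    {"kinase": "a", "protein": "b", "psite": "S1", "alpha": 0.2},
    {"kinase": "A ", "protein": "B", "psite": "S1", "alpha": 0.7},  # duplicate key
    {"kinase": "C", "protein": "D", "psite": "T5", "alpha": 0.4},   # pruned
]
collapsed = collapse_triplets(rows, retained={("A", "B", "S1")})
print(collapsed)
```

The output has one row per retained triplet, so a downstream builder that sums duplicate sparse coordinates sees the same single alpha that scoring and pruning used.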
def discover_hyperedges(encoded: EncodedNetwork, config: NetworkPreprocessingConfig) -> TripletTable:
    score=score_triplets(encoded.edge_weight, encoded.support_count, encoded.site_observed, encoded.kinase_observed, encoded.kinase_node_ids, encoded.substrate_node_ids)
    return TripletTable(encoded.kinase_ids, encoded.site_ids, encoded.substrate_ids, score, encoded.support_count, jnp.zeros_like(encoded.support_count, dtype=jnp.uint32))
Export pruning reasons for removed triplets
When any threshold/support/missing-observation filter removes triplets, the only exported table that still contains those removed candidates is discovered_hyperedges.csv, but this initializes all flags to zero and prune_triplets later returns only flags[keep]. As a result, pruning diagnostics never show why removed triplets were removed despite the flag bitmasks being computed; carry the pruning flags onto the discovered/full table or export a separate annotated table before applying keep.
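One way to keep the reasons is to compute flags for every candidate and return them alongside the keep mask, instead of returning only flags[keep]. The sketch below uses NumPy in place of jax.numpy; the bit values, FLAG_NAMES, and both function names are assumptions, since only the flag names appear in the quoted code.

```python
import numpy as np

# Bit values are assumptions; only the flag names appear in the quoted code.
FLAG_LOW_SCORE = np.uint32(1)
FLAG_LOW_SUPPORT = np.uint32(2)
FLAG_NAMES = {1: "low_score", 2: "low_support"}

def annotate_candidates(score, support_count, min_score, min_support):
    """Return (flags, keep) for every candidate, before any row is dropped."""
    score = np.asarray(score, dtype=np.float64)
    support_count = np.asarray(support_count)
    flags = np.zeros(score.shape, dtype=np.uint32)
    flags = np.where(score < min_score, flags | FLAG_LOW_SCORE, flags)
    flags = np.where(support_count < min_support, flags | FLAG_LOW_SUPPORT, flags)
    return flags, flags == 0

def decode(flag):
    """List the reason names set in one flag bitmask."""
    return [name for bit, name in FLAG_NAMES.items() if int(flag) & bit]

flags, keep = annotate_candidates([0.1, 2.0, 0.1, 2.0], [3, 0, 0, 3],
                                  min_score=1.0, min_support=1)
print(flags.tolist())             # [1, 2, 3, 0]
print(keep.tolist())              # [False, False, False, True]
print([decode(f) for f in flags])
```

Writing flags (or the decoded names) into the discovered table before applying keep gives removed triplets a visible reason in the export.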
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 2e68847452
def identifiability_kernel(indices, values):
    norms = jnp.abs(values).astype(jnp.float64)
    retained = norms > jnp.finfo(jnp.float64).eps
    keys = indices[:,0]*jnp.int32(1000003) + indices[:,1]*jnp.int32(1009) + indices[:,2]
Use collision-free keys for identifiability groups
When identifiability diagnostics run on larger sparse tensors, this int32 linear hash can group unrelated triplets together: for example (kinase=1, site=0, substrate=0) and (0, 991, 84) both produce 1000003. That silently corrupts group_id/redundancy diagnostics for valid networks with around a thousand sites and enough substrates; use lexicographic sorting/grouping or a collision-free mixed-radix key based on the actual tensor shape.
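The collision and a mixed-radix alternative can be checked with plain integers. The function names and the tensor shape below are hypothetical; the shape is just large enough to contain both example triplets.

```python
def linear_hash(kinase, site, substrate):
    """The key from the quoted kernel, in plain Python integers."""
    return kinase * 1000003 + site * 1009 + substrate

def mixed_radix_key(kinase, site, substrate, shape):
    """Collision-free key for in-range indices of a tensor with the given shape."""
    n_kinase, n_site, n_substrate = shape
    assert 0 <= kinase < n_kinase
    assert 0 <= site < n_site
    assert 0 <= substrate < n_substrate
    return (kinase * n_site + site) * n_substrate + substrate

a, b = (1, 0, 0), (0, 991, 84)
shape = (2, 1000, 100)  # hypothetical tensor shape holding both triplets

print(linear_hash(*a), linear_hash(*b))                        # 1000003 1000003
print(mixed_radix_key(*a, shape), mixed_radix_key(*b, shape))  # 100000 99184
```

The mixed-radix key is unique because each index stays below its own radix. In array code the key dtype would need to be wide enough for the product of the three dimensions (for example int64 once that product exceeds the int32 range); lexicographic sorting on the three index columns avoids the overflow question entirely.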
Motivation
- [...] docs/jax_hyperedge_network_preprocessing.md to run between data loading and model construction without changing default workflows.
- [...] float64 and are suitable for later JAXopt integration.
- [...] networkmodel.runner CLI.

Description
- [...] network_preprocessing implementing dataclasses, I/O adapters, JAX kernels, scoring/pruning/motif/identifiability APIs, CSV/NPZ exports, and plotting: network_preprocessing/{dataclasses.py,io_adapters.py,jax_kernels.py,api.py,export.py,plotting.py,__init__.py}.
- [...] network_preprocessing/jax_kernels.py using jax.jit, jax.numpy and explicitly enabling jax_enable_x64 to keep float64 behavior.
- [...] pandas frames to JAX arrays in io_adapters.py, build a COO-like sparse theta (indices, values, shape) in api.py, and write labelled CSV/NPZ/JSON artifacts via export.py and publication-quality PNGs via plotting.py.
- [...] networkmodel.runner behind new CLI flags (disabled by default): --enable-network-preprocessing plus scoped flags such as --network-preprocessing-min-score, --network-preprocessing-min-support, --network-preprocessing-max-triplets, --network-preprocessing-prune-missing-observations, and --network-preprocessing-keep-self-loops.

Testing
- python -m py_compile network_preprocessing/*.py networkmodel/runner.py completed successfully.
- pytest -q tests/test_network_preprocessing.py was attempted, but the environment lacks the jax package and collection failed with ModuleNotFoundError: No module named 'jax'. With jax present, the added tests exercise discovery, pruning, motif detection, sparse tensor shape/dtypes, identifiability diagnostics, JIT/float64 expectations, CSV/plot exports, and the runner CLI flag behavior.