diff --git a/.gitignore b/.gitignore
index e43b0f9..8f85e90 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,28 @@
.DS_Store
+
+# Manifest is machine/Julia-version specific — each environment generates its own.
+# Only Project.toml is committed. Run `julia --project=. -e 'using Pkg; Pkg.instantiate()'`
+# to regenerate on a new machine or Julia version.
+Manifest.toml
+Manifest.toml.bak
+
+ =========================
+# Editor
+# =========================
+.vscode/
+.idea/
+
+# =========================
+# Logs / temp
+# =========================
+*.log
+
+# =========================
+# Data / outputs (IMPORTANT for your project)
+# =========================
+results/
+output/
+*.tsv
+*.csv
+*.arrow
+
diff --git a/DEVELOPER.md b/DEVELOPER.md
new file mode 100644
index 0000000..37a6bb7
--- /dev/null
+++ b/DEVELOPER.md
@@ -0,0 +1,266 @@
+# Developer Notes — InferelatorJL
+
+Personal reference for working on this codebase.
+
+---
+
+## Starting Julia
+
+```bash
+julia # from any terminal
+```
+
+Julia has four REPL modes. You switch between them with single keystrokes:
+
+| Mode | Prompt | Enter with | Exit with |
+|---|---|---|---|
+| Normal | `julia>` | default | — |
+| Package (Pkg) | `(@v1.X) pkg>` | `]` | Backspace |
+| Shell | `shell>` | `;` | Backspace |
+| Help | `help?>` | `?` | Backspace |
+
+The `@v1.X` in the Pkg prompt names the active environment; the default environment is named after your Julia version (e.g. `@v1.12` locally, `@v1.7` on the cluster).
+
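+A quick round-trip through the modes (a sketch; the `]`, `;`, and `?` keystrokes switch
+the prompt immediately rather than being echoed):
+
+```julia
+julia> ]                 # keystroke: enter Pkg mode
+(@v1.12) pkg> status     # list packages in the active environment
+# press Backspace to return to julia>
+
+julia> ;                 # keystroke: enter shell mode
+shell> pwd
+
+julia> ?                 # keystroke: enter help mode
+help?> println
+```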
+---
+
+## One-time setup — local Mac (Julia 1.12)
+
+```julia
+julia> ]
+(@v1.12) pkg> dev /Users/owop7y/Desktop/InferelatorJL
+(@v1.12) pkg> instantiate
+```
+
+Then rebuild PyCall so PyPlot works (required once per Julia version):
+
+```julia
+(@v1.12) pkg> build PyCall
+```
+
+- `dev` registers the package by path — Julia reads directly from your folder, no copying.
+- `instantiate` resolves and downloads all dependencies, generating a local `Manifest.toml`.
+- `build PyCall` compiles the Python–Julia bridge for your current Python and Julia version.
+
+---
+
+## One-time setup — cluster (Julia 1.7.3)
+
+The cluster does not use `dev` mode — you work directly inside the package folder.
+Run this once after cloning or pulling the repo:
+
+```bash
+# On the cluster, inside the InferelatorJL directory:
+julia --project=. -e 'using Pkg; Pkg.instantiate()'
+julia --project=. -e 'using Pkg; Pkg.build("PyCall")'
+```
+
+This generates a fresh `Manifest.toml` on the cluster, resolved specifically for Julia 1.7.3.
+The cluster Manifest is gitignored — it never conflicts with your local one.
+
+---
+
+## Every session
+
+```julia
+using Revise # must come BEFORE the package
+using InferelatorJL # loads the package
+```
+
+**Always load Revise first.** It monitors `src/` for changes and patches the running
+session when you save a file — no restart needed.
+
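+To make this automatic, one pattern documented by Revise is loading it from
+`~/.julia/config/startup.jl`, so every REPL session gets it without typing:
+
+```julia
+# ~/.julia/config/startup.jl
+atreplinit() do repl
+    try
+        @eval using Revise
+    catch e
+        @warn "Revise failed to load" exception = (e, catch_backtrace())
+    end
+end
+```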
+---
+
+## Revise: what it can and cannot do
+
+| Change | Revise handles it? |
+|---|---|
+| Edit a function body | ✅ Live — takes effect on next call |
+| Add a new function | ✅ Live |
+| Add a new `include(...)` to `InferelatorJL.jl` | ✅ Live |
+| Add or rename a field in a struct (`Types.jl`) | ❌ Must restart Julia |
+| Change a struct's field type | ❌ Must restart Julia |
+| Rename or delete a struct | ❌ Must restart Julia |
+
+Struct changes are the main thing that forces a restart; day-to-day function edits are all live.
+
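+A typical live-edit cycle looks like this (the `data` variable is a stand-in for
+whatever the pipeline examples produce):
+
+```julia
+using Revise            # must come first
+using InferelatorJL
+
+result = buildNetwork(data)   # first call runs the current code
+# ...edit the function body in src/, save the file...
+result = buildNetwork(data)   # the next call runs the edited code, no restart
+```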
+---
+
+## Running examples
+
+From the REPL after loading the package:
+```julia
+include("examples/interactive_pipeline.jl")
+include("examples/utilityExamples.jl")
+```
+
+From the terminal (no REPL):
+```bash
+julia --project=. examples/run_pipeline.jl
+```
+
+---
+
+## Running tests
+
+Run the test suite directly — works on all Julia versions:
+
+```bash
+# From the terminal, inside the InferelatorJL directory:
+julia --project=. test/runtests.jl
+```
+
+> **Why not `] test`?**
+> In Julia ≥ 1.10, `Pkg.test()` creates a sandbox environment and calls
+> `check_registered()` on all test dependencies. Stdlib packages like `Test`
+> are not in the General registry, so this fails with
+> `"expected package Test to be registered"`.
+> Running the test file directly with `julia --project=.` bypasses the sandbox
+> and works correctly on all Julia versions.
+
+---
+
+## Working across machines and Julia versions
+
+`Manifest.toml` is **gitignored** — it is machine and Julia-version specific.
+Only `Project.toml` is committed. Each machine generates its own Manifest.
+
+| Machine | What to do once | What gets committed |
+|---|---|---|
+| Local Mac (1.12) | `] instantiate` + `] build PyCall` | nothing (Manifest is gitignored) |
+| Cluster (1.7.3) | `Pkg.instantiate()` + `Pkg.build("PyCall")` | nothing |
+| Any new machine | same as above | nothing |
+
+**After a `git pull` that changes `Project.toml`** (e.g. a new dependency was added), re-run:
+```bash
+julia --project=. -e 'using Pkg; Pkg.instantiate()'
+```
+
+**After upgrading Julia** to a new version, re-run both:
+```bash
+julia --project=. -e 'using Pkg; Pkg.instantiate()'
+julia --project=. -e 'using Pkg; Pkg.build("PyCall")'
+```
+
+**If `Pkg.instantiate()` fails** with `"empty intersection"` or `"unsatisfiable requirements"`:
+a specific package's compat bound in `Project.toml` is too narrow for the Julia version
+being used. The error names the package. Widen its bound in `Project.toml` and re-run.
+Example: `GLMNet = "0.4, 0.5, 0.6, 0.7"` → `"0.4, 0.5, 0.6, 0.7, 0.8"`.
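+
+The corresponding `[compat]` section in `Project.toml` would then look like this
+(bounds illustrative):
+
+```toml
+[compat]
+GLMNet = "0.4, 0.5, 0.6, 0.7, 0.8"   # widened to allow 0.8
+julia = "1.7"
+```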
+
+---
+
+## Checking what changed / what's loaded
+
+```julia
+# Which version is active?
+using Pkg; Pkg.status("InferelatorJL")
+
+# Where is it loaded from?
+pathof(InferelatorJL)
+
+# What does the package export?
+names(InferelatorJL)
+
+# What fields does a struct have?
+fieldnames(GeneExpressionData)
+fieldnames(GrnData)
+
+# Inspect a loaded struct at runtime
+data = GeneExpressionData()
+propertynames(data)
+```
+
+---
+
+## Switching between dev and release versions
+
+```julia
+# Currently using dev (your local folder):
+] status InferelatorJL # shows InferelatorJL [path] /Users/owop7y/Desktop/InferelatorJL
+
+# Switch to the released version (once it is published):
+] free InferelatorJL # removes the dev pin
+] add InferelatorJL # installs from registry
+
+# Switch back to dev:
+] dev /Users/owop7y/Desktop/InferelatorJL
+```
+
+---
+
+## Project environments (keeping work isolated)
+
+Every directory can have its own `Project.toml`. To work inside a specific project
+environment (e.g., a collaborator's analysis folder):
+
+```julia
+] activate /path/to/project # switch to that environment
+] status # see what is installed there
+] activate # return to your default environment (@v1.X)
+```
+
+When you run `] dev .` while a project environment is active, the package is
+dev'd only in that environment, not in your default `(@v1.X)` one.
+
+---
+
+## Common errors and what they mean
+
+| Error | Cause | Fix |
+|---|---|---|
+| `UndefVarError: InferelatorJL not defined` | Package not loaded | `using InferelatorJL` |
+| `UndefVarError: Revise not defined` | Revise not installed | `] add Revise` |
+| `Cannot redefine struct` | Changed `Types.jl` | Restart Julia |
+| `MethodError: no method matching ...` | Wrong argument types or order | Check function signature with `?functionname` |
+| `KeyError` on a dict field | Field name wrong | `fieldnames(StructType)` to check |
+| `PyCall not properly installed` | PyCall not built for this Julia version | `] build PyCall` |
+| `expected package Test to be registered` | `] test` sandbox bug in Julia ≥ 1.10 | Use `julia --project=. test/runtests.jl` instead |
+| `empty intersection between X@Y.Z and project compatibility` | Compat bound too narrow for installed version | Widen bound for that package in `Project.toml`, then `] instantiate` |
+| `Warning: Package X does not have Y in its dependencies` | Missing dep in `Project.toml` | Add it with `] add Y` |
+
+---
+
+## Package structure quick reference
+
+```
+src/
+ InferelatorJL.jl ← module entry point, all includes and using statements
+ Types.jl ← ALL struct definitions (edit here for new fields)
+ API.jl ← public API (loadData, buildNetwork, etc.)
+ data/ ← data loading functions
+ prior/ ← TF merging
+ utils/ ← DataUtils, NetworkIO, PartialCorrelation
+ grn/ ← pipeline core (PrepareGRN, BuildGRN, AggregateNetworks, RefineTFA)
+ metrics/ ← PR/ROC evaluation and plotting
+
+examples/
+ interactive_pipeline.jl ← step-by-step, public API (your main reference)
+ interactive_pipeline_dev.jl ← step-by-step, internal calls (your dev reference)
+ run_pipeline.jl ← function-wrapped, public API
+ run_pipeline_dev.jl ← function-wrapped, internal calls
+ utilityExamples.jl ← utility function demos, no real data needed
+ plotPR.jl ← PR curve evaluation
+
+test/
+ runtests.jl ← unit tests (run with: julia --project=. test/runtests.jl)
+```
+
+---
+
+## Adding a new exported function
+
+1. Write the function in the appropriate `src/` file.
+2. Add `export myFunction` to the export block in `src/InferelatorJL.jl`.
+3. Add a docstring above the function definition (Julia uses `"""..."""`).
+4. If it is a utility function without real-data requirements, add an example
+ to `examples/utilityExamples.jl`.
+
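+Steps 1–3 sketched with a hypothetical utility function (assumes `using Statistics`
+is already in the module, as it is a listed dependency):
+
+```julia
+# src/utils/DataUtils.jl
+"""
+    zscoreRows(M)
+
+Return `M` with each row centered and scaled to unit variance.
+(Hypothetical example function.)
+"""
+function zscoreRows(M::AbstractMatrix)
+    mu = mean(M, dims = 2)
+    sd = std(M, dims = 2)
+    (M .- mu) ./ sd
+end
+
+# src/InferelatorJL.jl, add to the export block:
+export zscoreRows
+```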
+---
+
+## Adding a new dependency
+
+1. `] add PackageName` — this updates both `Project.toml` and `Manifest.toml`.
+2. Add `using PackageName` in the appropriate `src/` file.
+3. Add a `[compat]` bound in `Project.toml` for the new package.
+4. **Commit only `Project.toml`** — `Manifest.toml` is gitignored and should not be committed.
+ Other machines will regenerate their own Manifest via `Pkg.instantiate()`.
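+
+The full flow for a hypothetical new dependency:
+
+```julia
+(@v1.12) pkg> add StatsPlots   # hypothetical package; updates Project.toml and Manifest.toml
+```
+
+Then add a compat bound by hand in `Project.toml`:
+
+```toml
+[compat]
+StatsPlots = "0.15"
+```
+
+and commit only `Project.toml`.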
diff --git a/Manifest.toml b/Manifest.toml
new file mode 100644
index 0000000..c753c5b
--- /dev/null
+++ b/Manifest.toml
@@ -0,0 +1,968 @@
+# This file is machine-generated - editing it directly is not advised
+
+julia_version = "1.12.5"
+manifest_format = "2.0"
+project_hash = "d561f5e18fb60a24178d6af859af5197e5eb087c"
+
+[[deps.Adapt]]
+deps = ["LinearAlgebra", "Requires"]
+git-tree-sha1 = "35ea197a51ce46fcd01c4a44befce0578a1aaeca"
+uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+version = "4.5.0"
+weakdeps = ["SparseArrays", "StaticArrays"]
+
+ [deps.Adapt.extensions]
+ AdaptSparseArraysExt = "SparseArrays"
+ AdaptStaticArraysExt = "StaticArrays"
+
+[[deps.AliasTables]]
+deps = ["PtrArrays", "Random"]
+git-tree-sha1 = "9876e1e164b144ca45e9e3198d0b689cadfed9ff"
+uuid = "66dad0bd-aa9a-41b7-9441-69ab47430ed8"
+version = "1.1.3"
+
+[[deps.ArgParse]]
+deps = ["Logging", "TextWrap"]
+git-tree-sha1 = "22cf435ac22956a7b45b0168abbc871176e7eecc"
+uuid = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
+version = "1.2.0"
+
+[[deps.ArgTools]]
+uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
+version = "1.1.2"
+
+[[deps.Arrow]]
+deps = ["ArrowTypes", "BitIntegers", "CodecLz4", "CodecZstd", "ConcurrentUtilities", "DataAPI", "Dates", "EnumX", "Mmap", "PooledArrays", "SentinelArrays", "StringViews", "Tables", "TimeZones", "TranscodingStreams", "UUIDs"]
+git-tree-sha1 = "4a69a3eadc1f7da78d950d1ef270c3a62c1f7e01"
+uuid = "69666777-d1a9-59fb-9406-91d4454c9d45"
+version = "2.8.1"
+
+[[deps.ArrowTypes]]
+deps = ["Sockets", "UUIDs"]
+git-tree-sha1 = "404265cd8128a2515a81d5eae16de90fdef05101"
+uuid = "31f734f8-188a-4ce0-8406-c8a06bd891cd"
+version = "2.3.0"
+
+[[deps.Artifacts]]
+uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
+version = "1.11.0"
+
+[[deps.AxisAlgorithms]]
+deps = ["LinearAlgebra", "Random", "SparseArrays", "WoodburyMatrices"]
+git-tree-sha1 = "01b8ccb13d68535d73d2b0c23e39bd23155fb712"
+uuid = "13072b0f-2c55-5437-9ae7-d433b7a33950"
+version = "1.1.0"
+
+[[deps.Base64]]
+uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
+version = "1.11.0"
+
+[[deps.BitIntegers]]
+deps = ["Random"]
+git-tree-sha1 = "091d591a060e43df1dd35faab3ca284925c48e46"
+uuid = "c3b6d118-76ef-56ca-8cc7-ebb389d030a1"
+version = "0.3.7"
+
+[[deps.CSV]]
+deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"]
+git-tree-sha1 = "8d8e0b0f350b8e1c91420b5e64e5de774c2f0f4d"
+uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
+version = "0.10.16"
+
+[[deps.CategoricalArrays]]
+deps = ["Compat", "DataAPI", "Future", "Missings", "Printf", "Requires", "Statistics", "Unicode"]
+git-tree-sha1 = "a6f644eb7bbc0171286f0f3ad1ffde8f04be7b83"
+uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597"
+version = "1.1.0"
+
+ [deps.CategoricalArrays.extensions]
+ CategoricalArraysArrowExt = "Arrow"
+ CategoricalArraysJSONExt = "JSON"
+ CategoricalArraysRecipesBaseExt = "RecipesBase"
+ CategoricalArraysSentinelArraysExt = "SentinelArrays"
+ CategoricalArraysStatsBaseExt = "StatsBase"
+ CategoricalArraysStructTypesExt = "StructTypes"
+
+ [deps.CategoricalArrays.weakdeps]
+ Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
+ JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
+ RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
+ SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
+ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+ StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
+
+[[deps.ChainRulesCore]]
+deps = ["Compat", "LinearAlgebra"]
+git-tree-sha1 = "e4c6a16e77171a5f5e25e9646617ab1c276c5607"
+uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+version = "1.26.0"
+weakdeps = ["SparseArrays"]
+
+ [deps.ChainRulesCore.extensions]
+ ChainRulesCoreSparseArraysExt = "SparseArrays"
+
+[[deps.CodecLz4]]
+deps = ["Lz4_jll", "TranscodingStreams"]
+git-tree-sha1 = "d58afcd2833601636b48ee8cbeb2edcb086522c2"
+uuid = "5ba52731-8f18-5e0d-9241-30f10d1ec561"
+version = "0.4.6"
+
+[[deps.CodecZlib]]
+deps = ["TranscodingStreams", "Zlib_jll"]
+git-tree-sha1 = "962834c22b66e32aa10f7611c08c8ca4e20749a9"
+uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
+version = "0.7.8"
+
+[[deps.CodecZstd]]
+deps = ["TranscodingStreams", "Zstd_jll"]
+git-tree-sha1 = "da54a6cd93c54950c15adf1d336cfd7d71f51a56"
+uuid = "6b39b394-51ab-5f42-8807-6242bab2b4c2"
+version = "0.8.7"
+
+[[deps.ColorTypes]]
+deps = ["FixedPointNumbers", "Random"]
+git-tree-sha1 = "67e11ee83a43eb71ddc950302c53bf33f0690dfe"
+uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
+version = "0.12.1"
+weakdeps = ["StyledStrings"]
+
+ [deps.ColorTypes.extensions]
+ StyledStringsExt = "StyledStrings"
+
+[[deps.Colors]]
+deps = ["ColorTypes", "FixedPointNumbers", "Reexport"]
+git-tree-sha1 = "37ea44092930b1811e666c3bc38065d7d87fcc74"
+uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
+version = "0.13.1"
+
+[[deps.Combinatorics]]
+git-tree-sha1 = "c761b00e7755700f9cdf5b02039939d1359330e1"
+uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
+version = "1.1.0"
+
+[[deps.Compat]]
+deps = ["TOML", "UUIDs"]
+git-tree-sha1 = "9d8a54ce4b17aa5bdce0ea5c34bc5e7c340d16ad"
+uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
+version = "4.18.1"
+weakdeps = ["Dates", "LinearAlgebra"]
+
+ [deps.Compat.extensions]
+ CompatLinearAlgebraExt = "LinearAlgebra"
+
+[[deps.CompilerSupportLibraries_jll]]
+deps = ["Artifacts", "Libdl"]
+uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
+version = "1.3.0+1"
+
+[[deps.ConcurrentUtilities]]
+deps = ["Serialization", "Sockets"]
+git-tree-sha1 = "21d088c496ea22914fe80906eb5bce65755e5ec8"
+uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb"
+version = "2.5.1"
+
+[[deps.Conda]]
+deps = ["Downloads", "JSON", "VersionParsing"]
+git-tree-sha1 = "8f06b0cfa4c514c7b9546756dbae91fcfbc92dc9"
+uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d"
+version = "1.10.3"
+
+[[deps.Crayons]]
+git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15"
+uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
+version = "4.1.1"
+
+[[deps.DataAPI]]
+git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe"
+uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
+version = "1.16.0"
+
+[[deps.DataFrames]]
+deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
+git-tree-sha1 = "d8928e9169ff76c6281f39a659f9bca3a573f24c"
+uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+version = "1.8.1"
+
+[[deps.DataStructures]]
+deps = ["OrderedCollections"]
+git-tree-sha1 = "e86f4a2805f7f19bec5129bc9150c38208e5dc23"
+uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
+version = "0.19.4"
+
+[[deps.DataValueInterfaces]]
+git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
+uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464"
+version = "1.0.0"
+
+[[deps.Dates]]
+deps = ["Printf"]
+uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
+version = "1.11.0"
+
+[[deps.DelimitedFiles]]
+deps = ["Mmap"]
+git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae"
+uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
+version = "1.9.1"
+
+[[deps.Distributed]]
+deps = ["Random", "Serialization", "Sockets"]
+uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
+version = "1.11.0"
+
+[[deps.Distributions]]
+deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"]
+git-tree-sha1 = "fbcc7610f6d8348428f722ecbe0e6cfe22e672c6"
+uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
+version = "0.25.123"
+
+ [deps.Distributions.extensions]
+ DistributionsChainRulesCoreExt = "ChainRulesCore"
+ DistributionsDensityInterfaceExt = "DensityInterface"
+ DistributionsTestExt = "Test"
+
+ [deps.Distributions.weakdeps]
+ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+ DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
+ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+
+[[deps.DocStringExtensions]]
+git-tree-sha1 = "7442a5dfe1ebb773c29cc2962a8980f47221d76c"
+uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
+version = "0.9.5"
+
+[[deps.Downloads]]
+deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
+uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
+version = "1.7.0"
+
+[[deps.EnumX]]
+git-tree-sha1 = "c49898e8438c828577f04b92fc9368c388ac783c"
+uuid = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
+version = "1.0.7"
+
+[[deps.ExprTools]]
+git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec"
+uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
+version = "0.1.10"
+
+[[deps.FileIO]]
+deps = ["Pkg", "Requires", "UUIDs"]
+git-tree-sha1 = "6522cfb3b8fe97bec632252263057996cbd3de20"
+uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
+version = "1.18.0"
+
+ [deps.FileIO.extensions]
+ HTTPExt = "HTTP"
+
+ [deps.FileIO.weakdeps]
+ HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
+
+[[deps.FilePathsBase]]
+deps = ["Compat", "Dates"]
+git-tree-sha1 = "3bab2c5aa25e7840a4b065805c0cdfc01f3068d2"
+uuid = "48062228-2e41-5def-b9a4-89aafe57970f"
+version = "0.9.24"
+weakdeps = ["Mmap", "Test"]
+
+ [deps.FilePathsBase.extensions]
+ FilePathsBaseMmapExt = "Mmap"
+ FilePathsBaseTestExt = "Test"
+
+[[deps.FileWatching]]
+uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
+version = "1.11.0"
+
+[[deps.FillArrays]]
+deps = ["LinearAlgebra"]
+git-tree-sha1 = "2f979084d1e13948a3352cf64a25df6bd3b4dca3"
+uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
+version = "1.16.0"
+weakdeps = ["PDMats", "SparseArrays", "StaticArrays", "Statistics"]
+
+ [deps.FillArrays.extensions]
+ FillArraysPDMatsExt = "PDMats"
+ FillArraysSparseArraysExt = "SparseArrays"
+ FillArraysStaticArraysExt = "StaticArrays"
+ FillArraysStatisticsExt = "Statistics"
+
+[[deps.FixedPointNumbers]]
+deps = ["Statistics"]
+git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172"
+uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
+version = "0.8.5"
+
+[[deps.Future]]
+deps = ["Random"]
+uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
+version = "1.11.0"
+
+[[deps.GLM]]
+deps = ["Distributions", "LinearAlgebra", "Printf", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "StatsModels"]
+git-tree-sha1 = "3bcb30438ee1655e3b9c42d97544de7addc9c589"
+uuid = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
+version = "1.9.3"
+
+[[deps.GLMNet]]
+deps = ["DataFrames", "Distributed", "Distributions", "Printf", "Random", "SparseArrays", "StatsBase", "glmnet_jll"]
+git-tree-sha1 = "b873c384d3490304c18224b1d5554cdebaafb60b"
+uuid = "8d5ece8b-de18-5317-b113-243142960cc6"
+version = "0.7.4"
+
+[[deps.HashArrayMappedTries]]
+git-tree-sha1 = "2eaa69a7cab70a52b9687c8bf950a5a93ec895ae"
+uuid = "076d061b-32b6-4027-95e0-9a2c6f6d7e74"
+version = "0.2.0"
+
+[[deps.HypergeometricFunctions]]
+deps = ["LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"]
+git-tree-sha1 = "68c173f4f449de5b438ee67ed0c9c748dc31a2ec"
+uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a"
+version = "0.3.28"
+
+[[deps.InferelatorJL]]
+deps = ["ArgParse", "Arrow", "CSV", "CategoricalArrays", "Colors", "DataFrames", "Dates", "DelimitedFiles", "Distributions", "FileIO", "GLM", "GLMNet", "InlineStrings", "Interpolations", "JLD2", "LinearAlgebra", "Measures", "NamedArrays", "OrderedCollections", "Printf", "ProgressBars", "PyPlot", "Random", "SparseArrays", "Statistics", "StatsBase", "TickTock"]
+path = "."
+uuid = "436bd8d6-fc45-48f7-bc1f-d4dd5aa384ad"
+version = "0.1.0"
+
+[[deps.InlineStrings]]
+git-tree-sha1 = "8f3d257792a522b4601c24a577954b0a8cd7334d"
+uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
+version = "1.4.5"
+weakdeps = ["ArrowTypes", "Parsers"]
+
+ [deps.InlineStrings.extensions]
+ ArrowTypesExt = "ArrowTypes"
+ ParsersExt = "Parsers"
+
+[[deps.InteractiveUtils]]
+deps = ["Markdown"]
+uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
+version = "1.11.0"
+
+[[deps.Interpolations]]
+deps = ["Adapt", "AxisAlgorithms", "ChainRulesCore", "LinearAlgebra", "OffsetArrays", "Random", "Ratios", "Requires", "SharedArrays", "SparseArrays", "StaticArrays", "WoodburyMatrices"]
+git-tree-sha1 = "88a101217d7cb38a7b481ccd50d21876e1d1b0e0"
+uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
+version = "0.15.1"
+
+ [deps.Interpolations.extensions]
+ InterpolationsUnitfulExt = "Unitful"
+
+ [deps.Interpolations.weakdeps]
+ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
+
+[[deps.InvertedIndices]]
+git-tree-sha1 = "6da3c4316095de0f5ee2ebd875df8721e7e0bdbe"
+uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
+version = "1.3.1"
+
+[[deps.IrrationalConstants]]
+git-tree-sha1 = "b2d91fe939cae05960e760110b328288867b5758"
+uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
+version = "0.2.6"
+
+[[deps.IteratorInterfaceExtensions]]
+git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
+uuid = "82899510-4779-5014-852e-03e436cf321d"
+version = "1.0.0"
+
+[[deps.JLD2]]
+deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "ScopedValues", "TranscodingStreams"]
+git-tree-sha1 = "d97791feefda45729613fafeccc4fbef3f539151"
+uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
+version = "0.5.15"
+
+ [deps.JLD2.extensions]
+ UnPackExt = "UnPack"
+
+ [deps.JLD2.weakdeps]
+ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
+
+[[deps.JLLWrappers]]
+deps = ["Artifacts", "Preferences"]
+git-tree-sha1 = "0533e564aae234aff59ab625543145446d8b6ec2"
+uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
+version = "1.7.1"
+
+[[deps.JSON]]
+deps = ["Dates", "Logging", "Parsers", "PrecompileTools", "StructUtils", "UUIDs", "Unicode"]
+git-tree-sha1 = "b3ad4a0255688dcb895a52fafbaae3023b588a90"
+uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
+version = "1.4.0"
+weakdeps = ["ArrowTypes"]
+
+ [deps.JSON.extensions]
+ JSONArrowExt = ["ArrowTypes"]
+
+[[deps.JuliaSyntaxHighlighting]]
+deps = ["StyledStrings"]
+uuid = "ac6e5ff7-fb65-4e79-a425-ec3bc9c03011"
+version = "1.12.0"
+
+[[deps.LaTeXStrings]]
+git-tree-sha1 = "dda21b8cbd6a6c40d9d02a73230f9d70fed6918c"
+uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
+version = "1.4.0"
+
+[[deps.LibCURL]]
+deps = ["LibCURL_jll", "MozillaCACerts_jll"]
+uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
+version = "0.6.4"
+
+[[deps.LibCURL_jll]]
+deps = ["Artifacts", "LibSSH2_jll", "Libdl", "OpenSSL_jll", "Zlib_jll", "nghttp2_jll"]
+uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
+version = "8.15.0+0"
+
+[[deps.LibGit2]]
+deps = ["LibGit2_jll", "NetworkOptions", "Printf", "SHA"]
+uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
+version = "1.11.0"
+
+[[deps.LibGit2_jll]]
+deps = ["Artifacts", "LibSSH2_jll", "Libdl", "OpenSSL_jll"]
+uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5"
+version = "1.9.0+0"
+
+[[deps.LibSSH2_jll]]
+deps = ["Artifacts", "Libdl", "OpenSSL_jll"]
+uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
+version = "1.11.3+1"
+
+[[deps.Libdl]]
+uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
+version = "1.11.0"
+
+[[deps.LinearAlgebra]]
+deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
+uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+version = "1.12.0"
+
+[[deps.LogExpFunctions]]
+deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
+git-tree-sha1 = "13ca9e2586b89836fd20cccf56e57e2b9ae7f38f"
+uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
+version = "0.3.29"
+
+ [deps.LogExpFunctions.extensions]
+ LogExpFunctionsChainRulesCoreExt = "ChainRulesCore"
+ LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables"
+ LogExpFunctionsInverseFunctionsExt = "InverseFunctions"
+
+ [deps.LogExpFunctions.weakdeps]
+ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+ ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
+ InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
+
+[[deps.Logging]]
+uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
+version = "1.11.0"
+
+[[deps.Lz4_jll]]
+deps = ["Artifacts", "JLLWrappers", "Libdl"]
+git-tree-sha1 = "191686b1ac1ea9c89fc52e996ad15d1d241d1e33"
+uuid = "5ced341a-0733-55b8-9ab6-a4889d929147"
+version = "1.10.1+0"
+
+[[deps.MacroTools]]
+git-tree-sha1 = "1e0228a030642014fe5cfe68c2c0a818f9e3f522"
+uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
+version = "0.5.16"
+
+[[deps.Markdown]]
+deps = ["Base64", "JuliaSyntaxHighlighting", "StyledStrings"]
+uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
+version = "1.11.0"
+
+[[deps.Measures]]
+git-tree-sha1 = "b513cedd20d9c914783d8ad83d08120702bf2c77"
+uuid = "442fdcdd-2543-5da2-b0f3-8c86c306513e"
+version = "0.3.3"
+
+[[deps.Missings]]
+deps = ["DataAPI"]
+git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d"
+uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
+version = "1.2.0"
+
+[[deps.Mmap]]
+uuid = "a63ad114-7e13-5084-954f-fe012c677804"
+version = "1.11.0"
+
+[[deps.Mocking]]
+deps = ["Compat", "ExprTools"]
+git-tree-sha1 = "2c140d60d7cb82badf06d8783800d0bcd1a7daa2"
+uuid = "78c3b35d-d492-501b-9361-3d52fe80e533"
+version = "0.8.1"
+
+[[deps.MozillaCACerts_jll]]
+uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
+version = "2025.11.4"
+
+[[deps.NamedArrays]]
+deps = ["Combinatorics", "DelimitedFiles", "InvertedIndices", "LinearAlgebra", "OrderedCollections", "Random", "Requires", "SparseArrays", "Statistics"]
+git-tree-sha1 = "33d258318d9e049d26c02ca31b4843b2c851c0b0"
+uuid = "86f7a689-2022-50b4-a561-43c23ac3c673"
+version = "0.10.5"
+
+[[deps.NetworkOptions]]
+uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
+version = "1.3.0"
+
+[[deps.OffsetArrays]]
+git-tree-sha1 = "117432e406b5c023f665fa73dc26e79ec3630151"
+uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
+version = "1.17.0"
+weakdeps = ["Adapt"]
+
+ [deps.OffsetArrays.extensions]
+ OffsetArraysAdaptExt = "Adapt"
+
+[[deps.OpenBLAS_jll]]
+deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
+uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
+version = "0.3.29+0"
+
+[[deps.OpenLibm_jll]]
+deps = ["Artifacts", "Libdl"]
+uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
+version = "0.8.7+0"
+
+[[deps.OpenSSL_jll]]
+deps = ["Artifacts", "Libdl"]
+uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95"
+version = "3.5.4+0"
+
+[[deps.OpenSpecFun_jll]]
+deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl"]
+git-tree-sha1 = "1346c9208249809840c91b26703912dff463d335"
+uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
+version = "0.5.6+0"
+
+[[deps.OrderedCollections]]
+git-tree-sha1 = "05868e21324cede2207c6f0f466b4bfef6d5e7ee"
+uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
+version = "1.8.1"
+
+[[deps.PDMats]]
+deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
+git-tree-sha1 = "e4cff168707d441cd6bf3ff7e4832bdf34278e4a"
+uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
+version = "0.11.37"
+weakdeps = ["StatsBase"]
+
+ [deps.PDMats.extensions]
+ StatsBaseExt = "StatsBase"
+
+[[deps.Parsers]]
+deps = ["Dates", "PrecompileTools", "UUIDs"]
+git-tree-sha1 = "7d2f8f21da5db6a806faf7b9b292296da42b2810"
+uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
+version = "2.8.3"
+
+[[deps.Pkg]]
+deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "Random", "SHA", "TOML", "Tar", "UUIDs", "p7zip_jll"]
+uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+version = "1.12.1"
+weakdeps = ["REPL"]
+
+ [deps.Pkg.extensions]
+ REPLExt = "REPL"
+
+[[deps.PooledArrays]]
+deps = ["DataAPI", "Future"]
+git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3"
+uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
+version = "1.4.3"
+
+[[deps.PrecompileTools]]
+deps = ["Preferences"]
+git-tree-sha1 = "07a921781cab75691315adc645096ed5e370cb77"
+uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
+version = "1.3.3"
+
+[[deps.Preferences]]
+deps = ["TOML"]
+git-tree-sha1 = "8b770b60760d4451834fe79dd483e318eee709c4"
+uuid = "21216c6a-2e73-6563-6e65-726566657250"
+version = "1.5.2"
+
+[[deps.PrettyTables]]
+deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "REPL", "Reexport", "StringManipulation", "Tables"]
+git-tree-sha1 = "624de6279ab7d94fc9f672f0068107eb6619732c"
+uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
+version = "3.3.2"
+
+ [deps.PrettyTables.extensions]
+ PrettyTablesTypstryExt = "Typstry"
+
+ [deps.PrettyTables.weakdeps]
+ Typstry = "f0ed7684-a786-439e-b1e3-3b82803b501e"
+
+[[deps.Printf]]
+deps = ["Unicode"]
+uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+version = "1.11.0"
+
+[[deps.ProgressBars]]
+deps = ["Printf"]
+git-tree-sha1 = "b437cdb0385ed38312d91d9c00c20f3798b30256"
+uuid = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
+version = "1.5.1"
+
+[[deps.PtrArrays]]
+git-tree-sha1 = "4fbbafbc6251b883f4d2705356f3641f3652a7fe"
+uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
+version = "1.4.0"
+
+[[deps.PyCall]]
+deps = ["Conda", "Dates", "Libdl", "LinearAlgebra", "MacroTools", "Serialization", "VersionParsing"]
+git-tree-sha1 = "9816a3826b0ebf49ab4926e2b18842ad8b5c8f04"
+uuid = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
+version = "1.96.4"
+
+[[deps.PyPlot]]
+deps = ["Colors", "LaTeXStrings", "PyCall", "Sockets", "Test", "VersionParsing"]
+git-tree-sha1 = "d2c2b8627bbada1ba00af2951946fb8ce6012c05"
+uuid = "d330b81b-6aea-500a-939a-2ce795aea3ee"
+version = "2.11.6"
+
+[[deps.QuadGK]]
+deps = ["DataStructures", "LinearAlgebra"]
+git-tree-sha1 = "9da16da70037ba9d701192e27befedefb91ec284"
+uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
+version = "2.11.2"
+
+ [deps.QuadGK.extensions]
+ QuadGKEnzymeExt = "Enzyme"
+
+ [deps.QuadGK.weakdeps]
+ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+
+[[deps.REPL]]
+deps = ["InteractiveUtils", "JuliaSyntaxHighlighting", "Markdown", "Sockets", "StyledStrings", "Unicode"]
+uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
+version = "1.11.0"
+
+[[deps.Random]]
+deps = ["SHA"]
+uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+version = "1.11.0"
+
+[[deps.Ratios]]
+deps = ["Requires"]
+git-tree-sha1 = "1342a47bf3260ee108163042310d26f2be5ec90b"
+uuid = "c84ed2f1-dad5-54f0-aa8e-dbefe2724439"
+version = "0.4.5"
+weakdeps = ["FixedPointNumbers"]
+
+ [deps.Ratios.extensions]
+ RatiosFixedPointNumbersExt = "FixedPointNumbers"
+
+[[deps.Reexport]]
+git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
+uuid = "189a3867-3050-52da-a836-e630ba90ab69"
+version = "1.2.2"
+
+[[deps.Requires]]
+deps = ["UUIDs"]
+git-tree-sha1 = "62389eeff14780bfe55195b7204c0d8738436d64"
+uuid = "ae029012-a4dd-5104-9daa-d747884805df"
+version = "1.3.1"
+
+[[deps.Rmath]]
+deps = ["Random", "Rmath_jll"]
+git-tree-sha1 = "5b3d50eb374cea306873b371d3f8d3915a018f0b"
+uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
+version = "0.9.0"
+
+[[deps.Rmath_jll]]
+deps = ["Artifacts", "JLLWrappers", "Libdl"]
+git-tree-sha1 = "58cdd8fb2201a6267e1db87ff148dd6c1dbd8ad8"
+uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f"
+version = "0.5.1+0"
+
+[[deps.SHA]]
+uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
+version = "0.7.0"
+
+[[deps.ScopedValues]]
+deps = ["HashArrayMappedTries", "Logging"]
+git-tree-sha1 = "ac4b837d89a58c848e85e698e2a2514e9d59d8f6"
+uuid = "7e506255-f358-4e82-b7e4-beb19740aa63"
+version = "1.6.0"
+
+[[deps.Scratch]]
+deps = ["Dates"]
+git-tree-sha1 = "9b81b8393e50b7d4e6d0a9f14e192294d3b7c109"
+uuid = "6c6a2e73-6563-6170-7368-637461726353"
+version = "1.3.0"
+
+[[deps.SentinelArrays]]
+deps = ["Dates", "Random"]
+git-tree-sha1 = "ebe7e59b37c400f694f52b58c93d26201387da70"
+uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
+version = "1.4.9"
+
+[[deps.Serialization]]
+uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+version = "1.11.0"
+
+[[deps.SharedArrays]]
+deps = ["Distributed", "Mmap", "Random", "Serialization"]
+uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
+version = "1.11.0"
+
+[[deps.ShiftedArrays]]
+git-tree-sha1 = "503688b59397b3307443af35cd953a13e8005c16"
+uuid = "1277b4bf-5013-50f5-be3d-901d8477a67a"
+version = "2.0.0"
+
+[[deps.Sockets]]
+uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
+version = "1.11.0"
+
+[[deps.SortingAlgorithms]]
+deps = ["DataStructures"]
+git-tree-sha1 = "64d974c2e6fdf07f8155b5b2ca2ffa9069b608d9"
+uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
+version = "1.2.2"
+
+[[deps.SparseArrays]]
+deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
+uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+version = "1.12.0"
+
+[[deps.SpecialFunctions]]
+deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
+git-tree-sha1 = "2700b235561b0335d5bef7097a111dc513b8655e"
+uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
+version = "2.7.2"
+weakdeps = ["ChainRulesCore"]
+
+ [deps.SpecialFunctions.extensions]
+ SpecialFunctionsChainRulesCoreExt = "ChainRulesCore"
+
+[[deps.StaticArrays]]
+deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
+git-tree-sha1 = "246a8bb2e6667f832eea063c3a56aef96429a3db"
+uuid = "90137ffa-7385-5640-81b9-e52037218182"
+version = "1.9.18"
+weakdeps = ["ChainRulesCore", "Statistics"]
+
+ [deps.StaticArrays.extensions]
+ StaticArraysChainRulesCoreExt = "ChainRulesCore"
+ StaticArraysStatisticsExt = "Statistics"
+
+[[deps.StaticArraysCore]]
+git-tree-sha1 = "6ab403037779dae8c514bad259f32a447262455a"
+uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
+version = "1.4.4"
+
+[[deps.Statistics]]
+deps = ["LinearAlgebra"]
+git-tree-sha1 = "ae3bb1eb3bba077cd276bc5cfc337cc65c3075c0"
+uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+version = "1.11.1"
+weakdeps = ["SparseArrays"]
+
+ [deps.Statistics.extensions]
+ SparseArraysExt = ["SparseArrays"]
+
+[[deps.StatsAPI]]
+deps = ["LinearAlgebra"]
+git-tree-sha1 = "178ed29fd5b2a2cfc3bd31c13375ae925623ff36"
+uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
+version = "1.8.0"
+
+[[deps.StatsBase]]
+deps = ["AliasTables", "DataAPI", "DataStructures", "IrrationalConstants", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
+git-tree-sha1 = "aceda6f4e598d331548e04cc6b2124a6148138e3"
+uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+version = "0.34.10"
+
+[[deps.StatsFuns]]
+deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
+git-tree-sha1 = "91f091a8716a6bb38417a6e6f274602a19aaa685"
+uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
+version = "1.5.2"
+
+ [deps.StatsFuns.extensions]
+ StatsFunsChainRulesCoreExt = "ChainRulesCore"
+ StatsFunsInverseFunctionsExt = "InverseFunctions"
+
+ [deps.StatsFuns.weakdeps]
+ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+ InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
+
+[[deps.StatsModels]]
+deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Printf", "REPL", "ShiftedArrays", "SparseArrays", "StatsAPI", "StatsBase", "StatsFuns", "Tables"]
+git-tree-sha1 = "08786db4a1346d17d0a8d952d2e66fd00fa18192"
+uuid = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
+version = "0.7.9"
+
+[[deps.StringManipulation]]
+deps = ["PrecompileTools"]
+git-tree-sha1 = "d05693d339e37d6ab134c5ab53c29fce5ee5d7d5"
+uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e"
+version = "0.4.4"
+
+[[deps.StringViews]]
+git-tree-sha1 = "f2dcb92855b31ad92fe8f079d4f75ac57c93e4b8"
+uuid = "354b36f9-a18e-4713-926e-db85100087ba"
+version = "1.3.7"
+
+[[deps.StructUtils]]
+deps = ["Dates", "UUIDs"]
+git-tree-sha1 = "fa95b3b097bcef5845c142ea2e085f1b2591e92c"
+uuid = "ec057cc2-7a8d-4b58-b3b3-92acb9f63b42"
+version = "2.7.1"
+
+ [deps.StructUtils.extensions]
+ StructUtilsMeasurementsExt = ["Measurements"]
+ StructUtilsStaticArraysCoreExt = ["StaticArraysCore"]
+ StructUtilsTablesExt = ["Tables"]
+
+ [deps.StructUtils.weakdeps]
+ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
+ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
+ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
+
+[[deps.StyledStrings]]
+uuid = "f489334b-da3d-4c2e-b8f0-e476e12c162b"
+version = "1.11.0"
+
+[[deps.SuiteSparse]]
+deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
+uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
+
+[[deps.SuiteSparse_jll]]
+deps = ["Artifacts", "Libdl", "libblastrampoline_jll"]
+uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
+version = "7.8.3+2"
+
+[[deps.TOML]]
+deps = ["Dates"]
+uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
+version = "1.0.3"
+
+[[deps.TZJData]]
+deps = ["Artifacts"]
+git-tree-sha1 = "72df96b3a595b7aab1e101eb07d2a435963a97e2"
+uuid = "dc5dba14-91b3-4cab-a142-028a31da12f7"
+version = "1.5.0+2025b"
+
+[[deps.TableTraits]]
+deps = ["IteratorInterfaceExtensions"]
+git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39"
+uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
+version = "1.0.1"
+
+[[deps.Tables]]
+deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"]
+git-tree-sha1 = "f2c1efbc8f3a609aadf318094f8fc5204bdaf344"
+uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
+version = "1.12.1"
+
+[[deps.Tar]]
+deps = ["ArgTools", "SHA"]
+uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
+version = "1.10.0"
+
+[[deps.Test]]
+deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
+uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+version = "1.11.0"
+
+[[deps.TextWrap]]
+git-tree-sha1 = "43044b737fa70bc12f6105061d3da38f881a3e3c"
+uuid = "b718987f-49a8-5099-9789-dcd902bef87d"
+version = "1.0.2"
+
+[[deps.TickTock]]
+deps = ["Dates"]
+git-tree-sha1 = "385ff4318d1159050cb129f908804ff95b830de0"
+uuid = "9ff05d80-102d-5586-aa04-3a8bd1a90d20"
+version = "1.3.0"
+
+[[deps.TimeZones]]
+deps = ["Artifacts", "Dates", "Downloads", "InlineStrings", "Mocking", "Printf", "Scratch", "TZJData", "Unicode", "p7zip_jll"]
+git-tree-sha1 = "d422301b2a1e294e3e4214061e44f338cafe18a2"
+uuid = "f269a46b-ccf7-5d73-abea-4c690281aa53"
+version = "1.22.2"
+
+ [deps.TimeZones.extensions]
+ TimeZonesRecipesBaseExt = "RecipesBase"
+
+ [deps.TimeZones.weakdeps]
+ RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
+
+[[deps.TranscodingStreams]]
+git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742"
+uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
+version = "0.11.3"
+
+[[deps.UUIDs]]
+deps = ["Random", "SHA"]
+uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
+version = "1.11.0"
+
+[[deps.Unicode]]
+uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
+version = "1.11.0"
+
+[[deps.VersionParsing]]
+git-tree-sha1 = "58d6e80b4ee071f5efd07fda82cb9fbe17200868"
+uuid = "81def892-9a0e-5fdd-b105-ffc91e053289"
+version = "1.3.0"
+
+[[deps.WeakRefStrings]]
+deps = ["DataAPI", "InlineStrings", "Parsers"]
+git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23"
+uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"
+version = "1.4.2"
+
+[[deps.WoodburyMatrices]]
+deps = ["LinearAlgebra", "SparseArrays"]
+git-tree-sha1 = "248a7031b3da79a127f14e5dc5f417e26f9f6db7"
+uuid = "efce3f68-66dc-5838-9240-27a6d6f5f9b6"
+version = "1.1.0"
+
+[[deps.WorkerUtilities]]
+git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7"
+uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60"
+version = "1.6.1"
+
+[[deps.Zlib_jll]]
+deps = ["Libdl"]
+uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
+version = "1.3.1+2"
+
+[[deps.Zstd_jll]]
+deps = ["Artifacts", "JLLWrappers", "Libdl"]
+git-tree-sha1 = "446b23e73536f84e8037f5dce465e92275f6a308"
+uuid = "3161d3a3-bdf6-5164-811a-617609db77b4"
+version = "1.5.7+1"
+
+[[deps.glmnet_jll]]
+deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
+git-tree-sha1 = "31adae3b983b579a1fbd7cfd43a4bc0d224c2f5a"
+uuid = "78c6b45d-5eaf-5d68-bcfb-a5a2cb06c27f"
+version = "2.0.13+0"
+
+[[deps.libblastrampoline_jll]]
+deps = ["Artifacts", "Libdl"]
+uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
+version = "5.15.0+0"
+
+[[deps.nghttp2_jll]]
+deps = ["Artifacts", "Libdl"]
+uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
+version = "1.64.0+1"
+
+[[deps.p7zip_jll]]
+deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
+uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
+version = "17.7.0+0"
diff --git a/UtilsRprog/DownSamplingObject.R b/Processing/DownSamplingObject.R
similarity index 100%
rename from UtilsRprog/DownSamplingObject.R
rename to Processing/DownSamplingObject.R
diff --git a/UtilsRprog/saveNormCountsArrowFIle.R b/Processing/saveNormCountsArrowFIle.R
similarity index 100%
rename from UtilsRprog/saveNormCountsArrowFIle.R
rename to Processing/saveNormCountsArrowFIle.R
diff --git a/Project.toml b/Project.toml
new file mode 100755
index 0000000..2d32763
--- /dev/null
+++ b/Project.toml
@@ -0,0 +1,61 @@
+name = "InferelatorJL"
+uuid = "436bd8d6-fc45-48f7-bc1f-d4dd5aa384ad"
+version = "0.1.0"
+
+[deps]
+ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
+Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
+CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
+CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
+Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
+DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
+GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
+GLMNet = "8d5ece8b-de18-5317-b113-243142960cc6"
+InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
+Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
+JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Measures = "442fdcdd-2543-5da2-b0f3-8c86c306513e"
+NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
+OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
+PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+TickTock = "9ff05d80-102d-5586-aa04-3a8bd1a90d20"
+
+[compat]
+julia = "1.7"
+ArgParse = "1"
+Arrow = "2"
+CSV = "0.10"
+CategoricalArrays = "0.10, 1"
+Colors = "0.12, 0.13"
+DataFrames = "1"
+Distributions = "0.25"
+FileIO = "1"
+GLM = "1"
+GLMNet = "0.4, 0.5, 0.6, 0.7"
+InlineStrings = "1"
+Interpolations = "0.13, 0.14, 0.15"
+JLD2 = "0.4, 0.5"
+Measures = "0.3"
+NamedArrays = "0.9, 0.10"
+OrderedCollections = "1"
+ProgressBars = "1"
+PyPlot = "2"
+StatsBase = "0.33, 0.34"
+TickTock = "1"
+
+[extras]
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+
+[targets]
+test = ["Test"]
diff --git a/README.md b/README.md
index 6f75c6e..1728d3b 100755
--- a/README.md
+++ b/README.md
@@ -1,14 +1,218 @@
-# Inferelator_Julia
+# InferelatorJL
-This repository contains a workflow for inference of transcriptional regulatory networks (TRNs) from gene expression data and prior information, as described in:
+A Julia package for inference of gene regulatory networks (GRNs) from
+gene expression data and prior chromatin accessibility or binding information,
+using the **mLASSO-StARS** algorithm.
-Miraldi et al., Leveraging chromatin accessibility for transcriptional regulatory network inference in T Helper 17 Cells.
+
-The main workflow is found here: \
-[workflow.jl](https://github.com/MiraldiLab/Inferelator_Julia/blob/main/Testing/workflow.jl) \
-This workflow can benefit from parallelization, with significant speedups from using multiple cores. The workflow can be run on the BMI cluster with n cores using the bat file: \
-[workflow.bat](https://github.com/MiraldiLab/Inferelator_Julia/blob/main/Testing/workflow.bat) \
+---
+## Overview
+InferelatorJL infers genome-scale gene regulatory networks (GRNs) by combining
+gene expression data with a prior network (e.g., from ATAC-seq or ChIP-seq) to
+identify transcription factor (TF) → target gene regulatory relationships.
+The method uses **modified LASSO with Stability Approach to Regularization
+Selection (mLASSO-StARS)**, a penalized regression framework in which the prior
+network biases edge selection, while stability selection across bootstrap
+subsamples determines a data-driven sparsity level.
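+
+The instability computation at the heart of StARS can be sketched in a few
+lines: fit a LASSO on many subsamples, record how often each edge is selected,
+and turn those selection frequencies into an instability score. The snippet
+below illustrates only that scoring step — the function names are hypothetical,
+not the package API.
+
+```julia
+# Sketch of StARS-style instability scoring (illustrative; not the package API).
+# selFreq[i] = fraction of subsamples in which edge i was selected at a given λ.
+# An edge selected always or never is perfectly stable (instability 0); an edge
+# selected half the time is maximally unstable (instability 0.5).
+edge_instability(θ::Float64) = 2θ * (1 - θ)
+
+network_instability(selFreq::Vector{Float64}) =
+    sum(edge_instability, selFreq) / length(selFreq)
+```
+
+λ is then chosen so that the (monotonized) network instability stays at the
+target level (`targetInstability`, 0.05 by default).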
-
+**Key features:**
+
+- Supports bulk RNA-seq, pseudobulk, and single-cell expression data
+- Estimates **TF activity (TFA)** via least squares as an alternative to raw TF mRNA
+- Builds separate networks for TFA and TF mRNA predictors, then combines them
+- Prior-weighted LASSO penalties reduce false positives for prior-supported edges
+- Network-level or gene-level instability thresholding
+- Built-in PR / ROC evaluation against gold-standard interaction sets
+- Fully multithreaded subsampling loop
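+
+The least-squares TFA estimate in the second bullet can be sketched as solving
+X ≈ Pᵀ·A for A, where P is the prior (TFs × genes) and X the expression matrix
+(genes × samples). The snippet below is a hypothetical illustration via the
+pseudoinverse, not the package implementation.
+
+```julia
+using LinearAlgebra
+
+# Hypothetical least-squares TFA sketch (not the package code).
+# P: prior, TFs × genes; X: expression, genes × samples.
+# Returns A: estimated TF activities, TFs × samples.
+estimate_tfa(P::AbstractMatrix, X::AbstractMatrix) = pinv(transpose(P)) * X
+```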
+
+---
+
+## Installation
+
+**Requirements:** Julia ≥ 1.7, Python (for PyPlot/matplotlib)
+
+```julia
+# From the Julia REPL:
+] add https://github.com/miraldilab/InferelatorJL.jl
+
+# Install all dependencies:
+] instantiate
+```
+
+**For development / local installation:**
+```julia
+] dev /path/to/InferelatorJL
+] instantiate
+```
+
+---
+
+## Quick start
+
+```julia
+using InferelatorJL
+
+# Step 1 — load expression data
+data = loadData(exprFile, targFile, regFile)
+
+# Steps 2–3 — merge degenerate TFs, build prior, estimate TFA
+priorData, mergedTFs = loadPrior(data, priorFile)
+estimateTFA(priorData, data; outputDir = dirOut)
+
+# Step 4 — infer GRN (TFA predictors)
+buildNetwork(data, priorData;
+ tfaMode = true,
+ outputDir = joinpath(dirOut, "TFA"))
+
+# Step 5 — aggregate TFA + mRNA networks
+aggregateNetworks([tfaEdges, mrnaEdges];
+ method = :max,
+ outputDir = joinpath(dirOut, "Combined"))
+
+# Step 6 — refine TFA using the consensus network as a new prior
+refineTFA(combinedNetFile, data, mergedTFs; outputDir = dirOut)
+```
+
+See [`examples/interactive_pipeline.jl`](examples/interactive_pipeline.jl) for a
+fully annotated step-by-step example.
+
+---
+
+## Pipeline
+
+The full pipeline has six steps. Each step is a single function call.
+
+| Step | Function | Description |
+|---|---|---|
+| 1 | `loadData` | Load expression matrix; filter to target genes and regulators |
+| 2–3 | `loadPrior` + `estimateTFA` | Process prior, merge degenerate TFs, estimate TF activity via least squares |
+| 4 | `buildNetwork` | Run mLASSO-StARS for one predictor mode (TFA or TF mRNA) |
+| 5 | `aggregateNetworks` | Combine TFA and mRNA edge lists into a consensus network |
+| 6 | `refineTFA` | Re-estimate TFA using the consensus network as a data-driven prior |
+| — | `evaluateNetwork` | Evaluate any network against gold standards (PR/ROC curves) |
+
+Steps 4–6 are typically run twice (TFA mode and TF mRNA mode) and the results
+aggregated in Step 5.
+
+---
+
+## Input files
+
+| File | Format | Description |
+|---|---|---|
+| Expression matrix | Tab-delimited TSV or Apache Arrow | Genes × samples (genes in rows, samples in columns); first column contains gene names |
+| Target gene list | Plain text, one gene per line | Genes to model as regression targets |
+| Regulator list | Plain text, one TF per line | Candidate transcription factors |
+| Prior network | Sparse TSV (TF × gene, empty first header) | Prior regulatory connections; values are edge weights (binary or continuous) |
+| Prior penalties | Same format as prior network | Prior(s) used to set per-edge LASSO penalties; can be the same as prior network |
+| TFA gene list | Plain text, one gene per line | Optional: restrict TFA estimation to a gene subset |
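+
+For concreteness, a tiny sparse prior TSV might look like the following (TFs in
+rows, genes in columns, values illustrative; note the empty first header cell):
+
+```text
+	GeneA	GeneB	GeneC
+Tf1	1	0	1
+Tf2	0	1	0
+```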
+
+---
+
+## Output files
+
+All outputs are written under the directory specified by `outputDir`.
+
+| File | Description |
+|---|---|
+| `edges.tsv` | Full ranked edge table: TF, gene, signed quantile, stability, partial correlation, inPrior flag |
+| `edges_subset.tsv` | Top edges after applying the `meanEdgesPerGene` cap |
+| `instability_*.jld2` | Raw instability arrays across the λ grid (for diagnostics) |
+| `combined_.tsv` | Aggregated network (long format) |
+| `combined__sp.tsv` | Aggregated network (sparse prior format, for downstream use) |
+
+---
+
+## Key parameters
+
+| Parameter | Default | Description |
+|---|---|---|
+| `totSS` | 80 | Total bootstrap subsamples for instability estimation |
+| `subsampleFrac` | 0.63 | Fraction of samples drawn per subsample |
+| `targetInstability` | 0.05 | Instability threshold for λ selection |
+| `lambdaBias` | `[0.5]` | Penalty reduction factor for prior-supported edges (1 = uniform penalty, prior ignored; smaller values favor prior-supported edges) |
+| `meanEdgesPerGene` | 20 | Maximum retained edges per target gene |
+| `instabilityLevel` | `"Network"` | `"Network"`: single λ for all genes; `"Gene"`: per-gene λ |
+| `zScoreTFA` | `true` | Z-score expression before TFA estimation |
+| `zScoreLASSO` | `true` | Z-score expression before LASSO regression |
+| `method` | `:max` | Network aggregation rule: `:max`, `:mean`, or `:min` stability |
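+
+The `:max` aggregation rule keeps, for each (TF, gene) edge, the highest
+stability observed across the input networks. A self-contained sketch (the
+function and edge names are hypothetical, not the package internals):
+
+```julia
+# Combine edge lists by maximum stability (illustrative sketch).
+function aggregate_max(edgeLists::Vector{Dict{Tuple{String,String},Float64}})
+    combined = Dict{Tuple{String,String},Float64}()
+    for edges in edgeLists, (edge, stability) in edges
+        combined[edge] = max(get(combined, edge, -Inf), stability)
+    end
+    return combined
+end
+
+tfa  = Dict(("Foxp3", "Il2ra") => 0.95, ("Rorc", "Il17a") => 0.40)
+mrna = Dict(("Foxp3", "Il2ra") => 0.80, ("Stat3", "Il17a") => 0.65)
+aggregate_max([tfa, mrna])  # ("Foxp3","Il2ra") keeps 0.95; unique edges pass through
+```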
+
+---
+
+## API reference
+
+### Data loading
+```julia
+data = loadData(exprFile, targFile, regFile; tfaGeneFile="", epsilon=0.01)
+priorData, mergedTFs = loadPrior(data, priorFile; minTargets=3)
+```
+
+### Core pipeline
+```julia
+estimateTFA(priorData, data; edgeSS=0, zScoreTFA=true, outputDir=".")
+buildNetwork(data, priorData; tfaMode=true, totSS=80, lambdaBias=[0.5], ...)
+aggregateNetworks(netFiles; method=:max, meanEdgesPerGene=20, outputDir=".")
+refineTFA(combinedNetFile, data, mergedTFs; zScoreTFA=true, outputDir=".")
+```
+
+### Evaluation
+```julia
+evaluateNetwork(gsFile, netFile; gsRegsFile="", targGeneFile="", outputDir=".")
+```
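+
+At each rank cutoff, precision and recall are computed against the gold-standard
+edge set. A minimal sketch of that computation (names hypothetical; the actual
+function also sweeps cutoffs to trace full PR/ROC curves):
+
+```julia
+# Precision/recall at rank cutoff k against a gold-standard edge set (sketch).
+function precision_recall(ranked::Vector, gold::Set, k::Int)
+    hits = count(in(gold), ranked[1:k])  # true positives among the top k edges
+    return hits / k, hits / length(gold)
+end
+```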
+
+### Utilities
+```julia
+convertToLong(df) # wide prior matrix → long format
+convertToWide(df; indices=(2,1,3)) # long → wide
+frobeniusNormalize(M, :column) # normalize matrix columns or rows
+binarizeNumeric!(df) # continuous prior → binary
+mergeDFs(dfs, :Gene, "sum") # merge prior DataFrames
+completeDF(df, :Gene, allGenes, allTFs) # align to full gene/TF universe
+writeTSVWithEmptyFirstHeader(df, path) # write sparse prior format
+check_column_norms(M) # verify unit column norms
+writeNetworkTable!(buildGrn; outputDir=".") # write edges.tsv from BuildGrn struct
+saveData(data, tfaData, grnData, buildGrn, dir, "checkpoint.jld2")
+```
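+
+As an example of what `frobeniusNormalize(M, :column)` and `check_column_norms`
+are for: column normalization divides each column by its Euclidean norm so that
+every column has unit norm. A hypothetical stand-alone sketch (not the package
+implementation):
+
+```julia
+using LinearAlgebra
+
+# Divide each column by its 2-norm; zero columns are left untouched.
+function normalize_columns(M::Matrix{Float64})
+    N = copy(M)
+    for j in axes(N, 2)
+        nrm = norm(view(N, :, j))
+        nrm > 0 && (view(N, :, j) ./= nrm)
+    end
+    return N
+end
+```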
+
+---
+
+## Examples
+
+| File | Description |
+|---|---|
+| [`interactive_pipeline.jl`](examples/interactive_pipeline.jl) | Full pipeline, step-by-step in the REPL, public API |
+| [`run_pipeline.jl`](examples/run_pipeline.jl) | Full pipeline wrapped in `runInferelator()` for batch use |
+| [`utilityExamples.jl`](examples/utilityExamples.jl) | All utility functions demonstrated with synthetic data; no input files needed |
+| [`plotPR.jl`](examples/plotPR.jl) | Evaluate networks against gold standards and generate PR curve plots |
+
+---
+
+## Data structures
+
+| Struct | Populated by | Key fields |
+|---|---|---|
+| `GeneExpressionData` | `loadData` | `expressionMat`, `geneNames`, `sampleNames`, `tfNames` |
+| `PriorTFAData` | `loadPrior`, `estimateTFA` | `priorMat`, `medTfas`, `tfNames`, `targetGenes` |
+| `mergedTFsResult` | `loadPrior` | `mergedTFs`, `tfNames`, `mergeTfLocVec` |
+| `GrnData` | `buildNetwork` internals | `predictorMat`, `penaltyMat`, `stabilityMat`, `subsampleMat` |
+| `BuildGrn` | `buildNetwork` | `regs`, `targs`, `signedQuantile`, `rankings`, `networkMat` |
+
+---
+
+## Citation
+
+If you use InferelatorJL in your work, please cite:
+
+> Miraldi ER, Pokrovskii M, Watters A, et al.
+> **Leveraging chromatin accessibility for transcriptional regulatory network inference in T Helper 17 Cells.**
+> *PLOS Computational Biology*, 2019.
+> https://doi.org/10.1371/journal.pcbi.1006979
+
+---
+
+## License
+
+See [LICENSE](LICENSE) for details.
diff --git a/archive/GRN.jl b/archive/GRN.jl
new file mode 100755
index 0000000..82ccbf8
--- /dev/null
+++ b/archive/GRN.jl
@@ -0,0 +1,116 @@
+module GRN
+
+ include("utilsGRN.jl")
+
+ using ..Data
+ using ..DataUtils
+ using ..PriorTFA
+ using ..MergeDegenerate
+ using ..NetworkIO
+
+ using GLMNet
+ using Random, Statistics, StatsBase
+ using Distributions
+ using DataFrames, CategoricalArrays, SparseArrays, CSV, DelimitedFiles
+ using LinearAlgebra
+ using TickTock
+ using JLD2
+ using ProgressBars
+ using Printf, Dates
+    using PyPlot
+    using NamedArrays
+    using ArgParse
+
+ mutable struct GrnData
+ predictorMat::Matrix{Float64}
+ penaltyMat::Matrix{Float64}
+ allPredictors::Vector{String}
+ subsamps::Matrix{Int64}
+ responseMat::Matrix{Float64}
+ maxLambdaNet::Float64
+ minLambdaNet::Float64
+ minLambdas::Matrix{Float64}
+ maxLambdas::Matrix{Float64}
+ netInstabilitiesUb::Vector{Float64}
+ netInstabilitiesLb::Vector{Float64}
+ instabilitiesUb::Matrix{Float64}
+ instabilitiesLb::Matrix{Float64}
+ netInstabilities::Vector{Float64}
+ geneInstabilities::Matrix{Float64}
+ lambdaRange::Vector{Float64}
+ lambdaRangeGene::Vector{Vector{Float64}}
+ stabilityMat::Array{Float64}
+ priorMatProcessed::Matrix{Float64}
+ betas::Array{Float64,3}
+        function GrnData()
+            return new(
+                Matrix{Float64}(undef, 0, 0),       # predictorMat
+                Matrix{Float64}(undef, 0, 0),       # penaltyMat
+                String[],                           # allPredictors
+                Matrix{Int64}(undef, 0, 0),         # subsamps
+                Matrix{Float64}(undef, 0, 0),       # responseMat
+                0.0,                                # maxLambdaNet
+                0.0,                                # minLambdaNet
+                Matrix{Float64}(undef, 0, 0),       # minLambdas
+                Matrix{Float64}(undef, 0, 0),       # maxLambdas
+                Float64[],                          # netInstabilitiesUb
+                Float64[],                          # netInstabilitiesLb
+                Matrix{Float64}(undef, 0, 0),       # instabilitiesUb
+                Matrix{Float64}(undef, 0, 0),       # instabilitiesLb
+                Float64[],                          # netInstabilities
+                Matrix{Float64}(undef, 0, 0),       # geneInstabilities
+                Float64[],                          # lambdaRange
+                Vector{Vector{Float64}}(undef, 0),  # lambdaRangeGene
+                Matrix{Float64}(undef, 0, 0),       # stabilityMat
+                Matrix{Float64}(undef, 0, 0),       # priorMatProcessed
+                Array{Float64,3}(undef, 0, 0, 0)    # betas
+            )
+        end
+ end
+
+
+ mutable struct BuildGrn
+ networkStability::Matrix{Float64}
+ lambda::Union{Float64, Vector{Float64}}
+ targs::Vector{String}
+ regs::Vector{String}
+ rankings::Vector{Float64}
+ signedQuantile::Vector{Float64}
+ partialCorrelation::Vector{Float64}
+ inPrior::Vector{String}
+ networkMat::Matrix{Any}
+ networkMatSubset::Matrix{Any}
+ inPriorVec::Vector{Float64}
+ betas::Matrix{Float64}
+        function BuildGrn()
+            return new(
+                Matrix{Float64}(undef, 0, 0),  # networkStability
+                0.0,                           # lambda
+                String[],                      # targs
+                String[],                      # regs
+                Float64[],                     # rankings
+                Float64[],                     # signedQuantile
+                Float64[],                     # partialCorrelation
+                String[],                      # inPrior
+                Matrix{Any}(undef, 0, 0),      # networkMat
+                Matrix{Any}(undef, 0, 0),      # networkMatSubset
+                Float64[],                     # inPriorVec
+                Matrix{Float64}(undef, 0, 0)   # betas
+            )
+        end
+ end
+
+ # Export only the functions that users need
+ export preparePredictorMat!, preparePenaltyMatrix!, constructSubsamples, bstarsWarmStart, bstartsEstimateInstability,
+ BuildGrn, GrnData, chooseLambda!, rankEdges!, combineGRNs, combineGRNS2
+
+ include("prepareGRN.jl") # functions that prepare predictor matrices etc.
+ include("buildGRN.jl") # functions that use GrnData / BuildGrn
+ include("aggregateNetworks.jl") # combineGRNs
+ include("refineTFA.jl") # combineGRNS2
+
+
+end
\ No newline at end of file
diff --git a/archive/Packages.txt b/archive/Packages.txt
new file mode 100755
index 0000000..0a84b65
--- /dev/null
+++ b/archive/Packages.txt
@@ -0,0 +1,23 @@
+JLD2
+CSV
+Arrow
+DataFrames
+OrderedCollections
+Interpolations
+InlineStrings
+StatsPlots
+PyPlot
+Colors
+Measures
+GLMNet
+Random
+StatsBase
+Distributions
+CategoricalArrays
+SparseArrays
+TickTock
+ProgressBars
+NamedArrays
+ArgParse
+FileIO
+GR
\ No newline at end of file
diff --git a/Testing/Packages.txt b/archive/Testing/Packages.txt
similarity index 100%
rename from Testing/Packages.txt
rename to archive/Testing/Packages.txt
diff --git a/Testing/R2_predict.bat b/archive/Testing/R2_predict.bat
similarity index 100%
rename from Testing/R2_predict.bat
rename to archive/Testing/R2_predict.bat
diff --git a/Testing/R2_predict.jl b/archive/Testing/R2_predict.jl
similarity index 100%
rename from Testing/R2_predict.jl
rename to archive/Testing/R2_predict.jl
diff --git a/Testing/combinePriors.jl b/archive/Testing/combinePriors.jl
similarity index 100%
rename from Testing/combinePriors.jl
rename to archive/Testing/combinePriors.jl
diff --git a/Testing/installRequirements.jl b/archive/Testing/installRequirements.jl
similarity index 100%
rename from Testing/installRequirements.jl
rename to archive/Testing/installRequirements.jl
diff --git a/Testing/misc/README.md b/archive/Testing/misc/README.md
similarity index 100%
rename from Testing/misc/README.md
rename to archive/Testing/misc/README.md
diff --git a/Testing/misc/calcPRinfTRNsMulti.jl b/archive/Testing/misc/calcPRinfTRNsMulti.jl
similarity index 100%
rename from Testing/misc/calcPRinfTRNsMulti.jl
rename to archive/Testing/misc/calcPRinfTRNsMulti.jl
diff --git a/Testing/misc/plotPR.jl b/archive/Testing/misc/plotPR.jl
similarity index 100%
rename from Testing/misc/plotPR.jl
rename to archive/Testing/misc/plotPR.jl
diff --git a/Testing/plotPR.jl b/archive/Testing/plotPR.jl
similarity index 100%
rename from Testing/plotPR.jl
rename to archive/Testing/plotPR.jl
diff --git a/Testing/plotPR_Old.jl b/archive/Testing/plotPR_Old.jl
similarity index 100%
rename from Testing/plotPR_Old.jl
rename to archive/Testing/plotPR_Old.jl
diff --git a/Testing/scratch/PR.jl b/archive/Testing/scratch/PR.jl
similarity index 100%
rename from Testing/scratch/PR.jl
rename to archive/Testing/scratch/PR.jl
diff --git a/Testing/scratch/TestingScript.jl b/archive/Testing/scratch/TestingScript.jl
similarity index 100%
rename from Testing/scratch/TestingScript.jl
rename to archive/Testing/scratch/TestingScript.jl
diff --git a/Testing/scratch/ToDo b/archive/Testing/scratch/ToDo
similarity index 100%
rename from Testing/scratch/ToDo
rename to archive/Testing/scratch/ToDo
diff --git a/Testing/scratch/calcPRinfTRNsMulti.jl b/archive/Testing/scratch/calcPRinfTRNsMulti.jl
similarity index 100%
rename from Testing/scratch/calcPRinfTRNsMulti.jl
rename to archive/Testing/scratch/calcPRinfTRNsMulti.jl
diff --git a/Testing/workflow.bat b/archive/Testing/workflow.bat
similarity index 100%
rename from Testing/workflow.bat
rename to archive/Testing/workflow.bat
diff --git a/Testing/workflow.jl b/archive/Testing/workflow.jl
similarity index 100%
rename from Testing/workflow.jl
rename to archive/Testing/workflow.jl
diff --git a/Testing/workflowCombined.jl b/archive/Testing/workflowCombined.jl
similarity index 100%
rename from Testing/workflowCombined.jl
rename to archive/Testing/workflowCombined.jl
diff --git a/Testing/workflowCombinedMulti.jl b/archive/Testing/workflowCombinedMulti.jl
similarity index 100%
rename from Testing/workflowCombinedMulti.jl
rename to archive/Testing/workflowCombinedMulti.jl
diff --git a/archive/installPackages.jl b/archive/installPackages.jl
new file mode 100755
index 0000000..75e7849
--- /dev/null
+++ b/archive/installPackages.jl
@@ -0,0 +1,35 @@
+# Pkg module for managing packages
+using Pkg
+# Define the path to the file containing required package names
+package_file = "/data/miraldiNB/Michael/Scripts/GRN/InferelatorJL/Packages.txt"
+
+# Read package names from the file
+required_packages = String[]
+if isfile(package_file)
+ # Open the file and read package names
+ open(package_file) do file
+ for line in eachline(file)
+ line = strip(line)
+ if !isempty(line)
+ push!(required_packages, line)
+ end
+ end
+ end
+else
+ println("Error: Package requirements file '$package_file' not found.")
+ exit(1)
+end
+
+# deps = keys(Pkg.project().dependencies)
+
+# Check each required package by trying to load it; install it if loading fails.
+for pkg in required_packages
+    try
+        @eval using $(Symbol(pkg))  # loads successfully if already installed
+    catch
+        println("Installing $pkg...")
+        Pkg.add(pkg)
+        @eval using $(Symbol(pkg))  # load the package after installation
+    end
+end
diff --git a/julia_fxns/CombineTRN/.RData b/archive/julia_fxns/CombineTRN/.RData
similarity index 100%
rename from julia_fxns/CombineTRN/.RData
rename to archive/julia_fxns/CombineTRN/.RData
diff --git a/julia_fxns/CombineTRN/.Rhistory b/archive/julia_fxns/CombineTRN/.Rhistory
similarity index 100%
rename from julia_fxns/CombineTRN/.Rhistory
rename to archive/julia_fxns/CombineTRN/.Rhistory
diff --git a/julia_fxns/CombineTRN/combineTRN1.jl b/archive/julia_fxns/CombineTRN/combineTRN1.jl
similarity index 100%
rename from julia_fxns/CombineTRN/combineTRN1.jl
rename to archive/julia_fxns/CombineTRN/combineTRN1.jl
diff --git a/julia_fxns/CombineTRN/combineTRN2.jl b/archive/julia_fxns/CombineTRN/combineTRN2.jl
similarity index 100%
rename from julia_fxns/CombineTRN/combineTRN2.jl
rename to archive/julia_fxns/CombineTRN/combineTRN2.jl
diff --git a/julia_fxns/buildTRNs_mLassoStARS.jl b/archive/julia_fxns/buildTRNs_mLassoStARS.jl
similarity index 100%
rename from julia_fxns/buildTRNs_mLassoStARS.jl
rename to archive/julia_fxns/buildTRNs_mLassoStARS.jl
diff --git a/julia_fxns/calcAupr.jl b/archive/julia_fxns/calcAupr.jl
similarity index 100%
rename from julia_fxns/calcAupr.jl
rename to archive/julia_fxns/calcAupr.jl
diff --git a/julia_fxns/calcPRinfTRNs.jl b/archive/julia_fxns/calcPRinfTRNs.jl
similarity index 100%
rename from julia_fxns/calcPRinfTRNs.jl
rename to archive/julia_fxns/calcPRinfTRNs.jl
diff --git a/julia_fxns/calcR2predFromStabilities.jl b/archive/julia_fxns/calcR2predFromStabilities.jl
similarity index 100%
rename from julia_fxns/calcR2predFromStabilities.jl
rename to archive/julia_fxns/calcR2predFromStabilities.jl
diff --git a/julia_fxns/estimateInstabilitiesTRNbStARS.jl b/archive/julia_fxns/estimateInstabilitiesTRNbStARS.jl
similarity index 100%
rename from julia_fxns/estimateInstabilitiesTRNbStARS.jl
rename to archive/julia_fxns/estimateInstabilitiesTRNbStARS.jl
diff --git a/julia_fxns/getMLassoStARSinstabilitiesPerGeneAndNet.jl b/archive/julia_fxns/getMLassoStARSinstabilitiesPerGeneAndNet.jl
similarity index 100%
rename from julia_fxns/getMLassoStARSinstabilitiesPerGeneAndNet.jl
rename to archive/julia_fxns/getMLassoStARSinstabilitiesPerGeneAndNet.jl
diff --git a/julia_fxns/getMLassoStARSlambdaRangePerGene.jl b/archive/julia_fxns/getMLassoStARSlambdaRangePerGene.jl
similarity index 100%
rename from julia_fxns/getMLassoStARSlambdaRangePerGene.jl
rename to archive/julia_fxns/getMLassoStARSlambdaRangePerGene.jl
diff --git a/julia_fxns/groupSelection.jl b/archive/julia_fxns/groupSelection.jl
similarity index 100%
rename from julia_fxns/groupSelection.jl
rename to archive/julia_fxns/groupSelection.jl
diff --git a/julia_fxns/importGeneExpGeneLists.jl b/archive/julia_fxns/importGeneExpGeneLists.jl
similarity index 100%
rename from julia_fxns/importGeneExpGeneLists.jl
rename to archive/julia_fxns/importGeneExpGeneLists.jl
diff --git a/julia_fxns/integratePrior_estTFA.jl b/archive/julia_fxns/integratePrior_estTFA.jl
similarity index 100%
rename from julia_fxns/integratePrior_estTFA.jl
rename to archive/julia_fxns/integratePrior_estTFA.jl
diff --git a/julia_fxns/mergeDegeneratePriorTFs.jl b/archive/julia_fxns/mergeDegeneratePriorTFs.jl
similarity index 100%
rename from julia_fxns/mergeDegeneratePriorTFs.jl
rename to archive/julia_fxns/mergeDegeneratePriorTFs.jl
diff --git a/julia_fxns/partialCorrelation.jl b/archive/julia_fxns/partialCorrelation.jl
similarity index 100%
rename from julia_fxns/partialCorrelation.jl
rename to archive/julia_fxns/partialCorrelation.jl
diff --git a/julia_fxns/plotConfidenceDistribution.jl b/archive/julia_fxns/plotConfidenceDistribution.jl
similarity index 100%
rename from julia_fxns/plotConfidenceDistribution.jl
rename to archive/julia_fxns/plotConfidenceDistribution.jl
diff --git a/julia_fxns/plotMetricUtils.jl b/archive/julia_fxns/plotMetricUtils.jl
similarity index 100%
rename from julia_fxns/plotMetricUtils.jl
rename to archive/julia_fxns/plotMetricUtils.jl
diff --git a/julia_fxns/priorUtils.jl b/archive/julia_fxns/priorUtils.jl
similarity index 100%
rename from julia_fxns/priorUtils.jl
rename to archive/julia_fxns/priorUtils.jl
diff --git a/julia_fxns/scratch/mergeDegeneratePriorTFs.jl b/archive/julia_fxns/scratch/mergeDegeneratePriorTFs.jl
similarity index 100%
rename from julia_fxns/scratch/mergeDegeneratePriorTFs.jl
rename to archive/julia_fxns/scratch/mergeDegeneratePriorTFs.jl
diff --git a/julia_fxns/updatePenaltyMatrix.jl b/archive/julia_fxns/updatePenaltyMatrix.jl
similarity index 100%
rename from julia_fxns/updatePenaltyMatrix.jl
rename to archive/julia_fxns/updatePenaltyMatrix.jl
diff --git a/archive/packages01.txt b/archive/packages01.txt
new file mode 100755
index 0000000..c3e7336
--- /dev/null
+++ b/archive/packages01.txt
@@ -0,0 +1,148 @@
+Adapt [79e6a3ab-5dfb-504d-930d-738a2a938a0e]
+AliasTables [66dad0bd-aa9a-41b7-9441-69ab47430ed8]
+ArgParse [c7e460c6-2fb9-53a9-8c5b-16f535851c63]
+ArgTools [0dad84c5-d112-42e6-8d28-ef12dabb789f]
+Arrow [69666777-d1a9-59fb-9406-91d4454c9d45]
+ArrowTypes [31f734f8-188a-4ce0-8406-c8a06bd891cd]
+Artifacts [56f22d72-fd6d-98f1-02f0-08ddc0907c33]
+AxisAlgorithms [13072b0f-2c55-5437-9ae7-d433b7a33950]
+Base [top-level]
+Base64 [2a0f44e3-6c83-55bd-87e4-b1978d98bd5f]
+BitIntegers [c3b6d118-76ef-56ca-8cc7-ebb389d030a1]
+CRC32c [8bf52ea8-c179-5cab-976a-9e18b702a9bc]
+CSV [336ed68f-0bac-5ca0-87d4-7b16caf5d00b]
+CategoricalArrays [324d7699-5711-5eae-9e2f-1d82baa6b597]
+ChainRulesCore [d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4]
+ChangesOfVariables [9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0]
+CodeTracking [da1fd8a2-8d9e-5ec2-8556-3022fb5608a2]
+CodecLz4 [5ba52731-8f18-5e0d-9241-30f10d1ec561]
+CodecZlib [944b1d66-785c-5afd-91f1-9de20f533193]
+CodecZstd [6b39b394-51ab-5f42-8807-6242bab2b4c2]
+ColorTypes [3da002f7-5984-5a60-b8a6-cbb66c0b333f]
+Colors [5ae59095-9a9b-59fe-a467-6f913c188581]
+Combinatorics [861a8166-3701-5b0c-9a16-15d98fcdc6aa]
+Compat [34da2185-b29b-5c13-b0c7-acf172513d20]
+CompilerSupportLibraries_jll [e66e0078-7015-5450-92f7-15fbd957f2ae]
+ConcurrentUtilities [f0e56b4a-5159-44fe-b623-3e5288b988bb]
+Conda [8f4d0f93-b110-5947-807f-2305c1781a2d]
+Core [top-level]
+Crayons [a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f]
+DataAPI [9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a]
+DataFrames [a93c6f00-e57d-5684-b7b6-d8193f3e46c0]
+DataStructures [864edb3b-99cc-5e75-8d2d-829cb0a9cfe8]
+DataValueInterfaces [e2d170a0-9d28-54be-80f0-106bbe20a464]
+Dates [ade2ca70-3891-5945-98fb-dc099432e06a]
+DelimitedFiles [8bb1440f-4735-579b-a4ab-409b98df4dab]
+DensityInterface [b429d917-457f-4dbc-8f4c-0cc954292b1d]
+Distributed [8ba89e20-285c-5b6f-9357-94700520ee1b]
+Distributions [31c24e10-a181-5473-b8eb-7969acd0382f]
+DocStringExtensions [ffbed154-4ef7-542d-bbb7-c09d3a79fcae]
+Downloads [f43a241f-c20a-4ad4-852c-f6b1247861c6]
+EnumX [4e289a0a-7415-4d19-859d-a7e5c4648b56]
+ExprTools [e2ba6199-217a-4e67-a87a-7c52f15ade04]
+FileIO [5789e2e9-d7fb-5bc7-8068-2c6fae9b9549]
+FilePathsBase [48062228-2e41-5def-b9a4-89aafe57970f]
+FileWatching [7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee]
+FillArrays [1a297f60-69ca-5386-bcde-b61e274b549b]
+FixedPointNumbers [53c48c17-4a7d-5ca2-90c5-79b7896eea93]
+Future [9fa8497b-333b-5362-9e8d-4d0656e87820]
+GLMNet [8d5ece8b-de18-5317-b113-243142960cc6]
+HypergeometricFunctions [34004b35-14d8-5ef3-9330-4cdb6864b03a]
+InlineStrings [842dd82b-1e85-43dc-bf29-5d0ee9dffc48]
+InteractiveUtils [b77e0a4c-d291-57a0-90e8-8db25a27a240]
+Interpolations [a98d9a8b-a2ab-59e6-89dd-64a1c18fca59]
+InverseFunctions [3587e190-3f89-42d0-90ee-14403ec27112]
+InvertedIndices [41ab1584-1d38-5bbf-9106-f11c6c58b48f]
+IrrationalConstants [92d709cd-6900-40b7-9082-c6be49f344b6]
+IteratorInterfaceExtensions [82899510-4779-5014-852e-03e436cf321d]
+JLD2 [033835bb-8acc-5ee8-8aae-3f567f8a3819]
+JLLWrappers [692b3bcd-3c85-4b1f-b108-f13ce0eb3210]
+JSON [682c06a0-de6a-54ab-a142-c8b1cf79cde6]
+JuliaInterpreter [aa1ae85d-cabe-5617-a682-6adf51b2e16a]
+LaTeXStrings [b964fa9f-0449-5b57-a5c2-d3ea65f4040f]
+LazyArtifacts [4af54fe1-eca0-43a8-85a7-787d91b784e3]
+LibCURL [b27032c2-a3e7-50c8-80cd-2d36dbcbfd21]
+LibCURL_jll [deac9b47-8bc7-5906-a0fe-35ac56dc84c0]
+LibGit2 [76f85450-5226-5b5a-8eaa-529ad045b433]
+Libdl [8f399da3-3557-5675-b5ff-fb832c97cbdb]
+LinearAlgebra [37e2e46d-f89d-539d-b4ee-838fcccc9c8e]
+LogExpFunctions [2ab3a3ac-af41-5b50-aa03-7779005ae688]
+Logging [56ddb016-857b-54e1-b83d-db4d58db5568]
+LoweredCodeUtils [6f1432cf-f94c-5a45-995e-cdbf5db27b0b]
+Lz4_jll [5ced341a-0733-55b8-9ab6-a4889d929147]
+MacroTools [1914dd2f-81c6-5fcd-8719-6d5c9610ff09]
+Main [top-level]
+Markdown [d6f4376e-aef5-505a-96c1-9c027394607a]
+Measures [442fdcdd-2543-5da2-b0f3-8c86c306513e]
+Missings [e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28]
+Mmap [a63ad114-7e13-5084-954f-fe012c677804]
+Mocking [78c3b35d-d492-501b-9361-3d52fe80e533]
+MozillaCACerts_jll [14a3606d-f60d-562e-9121-12d972cd8159]
+NamedArrays [86f7a689-2022-50b4-a561-43c23ac3c673]
+NetworkOptions [ca575930-c2e3-43a9-ace4-1e988b2c1908]
+OffsetArrays [6fe1bfb0-de20-5000-8ca7-80f57d26f881]
+OpenLibm_jll [05823500-19ac-5b8b-9628-191a04bc5112]
+OpenSSL_jll [458c3c95-2e84-50aa-8efc-19380b2a3a95]
+OpenSpecFun_jll [efe28fd5-8261-553b-a9e1-b2916fc3738e]
+OrderedCollections [bac558e1-5e72-5ebc-8fee-abe8a469f55d]
+PDMats [90014a1f-27ba-587c-ab20-58faa44d9150]
+Parsers [69de0a69-1ddd-5017-9359-2bf0b02dc9f0]
+Pkg [44cfe95a-1eb2-52ea-b672-e2afdf69b78f]
+PooledArrays [2dfb63ee-cc39-5dd5-95bd-886bf059d720]
+PrecompileTools [aea7be01-6a6a-4083-8856-8a6e6704d82a]
+Preferences [21216c6a-2e73-6563-6e65-726566657250]
+PrettyTables [08abe8d2-0d0c-5749-adfa-8a2ac140af0d]
+Printf [de0858da-6303-5e67-8744-51eddeeeb8d7]
+Profile [9abbd945-dff8-562f-b5e8-e1ebf5ef1b79]
+ProgressBars [49802e3a-d2f1-5c88-81d8-b72133a6f568]
+PtrArrays [43287f4e-b6f4-7ad1-bb20-aadabca52c3d]
+PyCall [438e738f-606a-5dbb-bf0a-cddfbfd45ab0]
+PyPlot [d330b81b-6aea-500a-939a-2ce795aea3ee]
+QuadGK [1fd47b50-473d-5c70-9696-f719f8f3bcdc]
+REPL [3fa0cd96-eef1-5676-8a61-b3b8758bbffb]
+Random [9a3f8284-a2c9-5f02-9a11-845980a1fd5c]
+Ratios [c84ed2f1-dad5-54f0-aa8e-dbefe2724439]
+RecipesBase [3cdcf5f2-1ef4-517c-9805-6587b60abb01]
+Reexport [189a3867-3050-52da-a836-e630ba90ab69]
+Requires [ae029012-a4dd-5104-9daa-d747884805df]
+Revise [295af30f-e4ad-537b-8983-00126c2a3abe]
+Rmath [79098fc4-a85e-5d69-aa6a-4863f24498fa]
+Rmath_jll [f50d1b31-88e8-58de-be2c-1cc44531875f]
+SHA [ea8e919c-243c-51af-8825-aaa63cd721ce]
+Scratch [6c6a2e73-6563-6170-7368-637461726353]
+SentinelArrays [91c51154-3ec4-41a3-a24f-3f23e20d615c]
+Serialization [9e88b42a-f829-5b0c-bbe9-9e923198166b]
+SharedArrays [1a1011a3-84de-559e-8e89-a11a2f7dc383]
+Sockets [6462fe0b-24de-5631-8697-dd941f90decc]
+SortingAlgorithms [a2af1166-a08f-5f64-846c-94a0d3cef48c]
+SparseArrays [2f01184e-e22b-5df5-ae63-d93ebab69eaf]
+SpecialFunctions [276daf66-3868-5448-9aa4-cd146d93841b]
+StaticArrays [90137ffa-7385-5640-81b9-e52037218182]
+StaticArraysCore [1e83bf80-4336-4d27-bf5d-d5a4f845583c]
+Statistics [10745b16-79ce-11e8-11f9-7d13ad32a3b2]
+StatsAPI [82ae8749-77ed-4fe6-ae5f-f523153014b0]
+StatsBase [2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91]
+StatsFuns [4c63d2b9-4356-54db-8cca-17b64c39e42c]
+StringManipulation [892a3eda-7b42-436c-8928-eab12a02cf0e]
+SuiteSparse [4607b0f0-06f3-5cda-b6b1-a6196a1729e9]
+TOML [fa267f1f-6049-4f14-aa54-33bafae1ed76]
+TZJData [dc5dba14-91b3-4cab-a142-028a31da12f7]
+TableTraits [3783bdb8-4a98-5b6b-af9a-565f29a5fe9c]
+Tables [bd369af6-aec1-5ad0-b16a-f7cc5008161c]
+Tar [a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e]
+Test [8dfed614-e22c-5e08-85e1-65c5234f0b40]
+TextWrap [b718987f-49a8-5099-9789-dcd902bef87d]
+TickTock [9ff05d80-102d-5586-aa04-3a8bd1a90d20]
+TimeZones [f269a46b-ccf7-5d73-abea-4c690281aa53]
+TranscodingStreams [3bb67fe8-82b1-5028-8e26-92a6c54297fa]
+UUIDs [cf7118a7-6976-5b1a-9a39-7adc72f591a4]
+Unicode [4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5]
+VersionParsing [81def892-9a0e-5fdd-b105-ffc91e053289]
+WeakRefStrings [ea10d353-3f73-51f8-a26c-33c1cb351aa5]
+WoodburyMatrices [efce3f68-66dc-5838-9240-27a6d6f5f9b6]
+WorkerUtilities [76eceee3-57b5-4d4a-8e66-0e911cebbf60]
+Zlib_jll [83775a58-1f1d-513f-b197-d71354ab007a]
+Zstd_jll [3161d3a3-bdf6-5164-811a-617609db77b4]
+glmnet_jll [78c6b45d-5eaf-5d68-bcfb-a5a2cb06c27f]
+nghttp2_jll [8e850ede-7688-5339-a07c-302acd2aaf8d]
+p7zip_jll [3f19e933-33d8-53b3-aaab-bd5110c3b7a0]
diff --git a/evaluation/R/ComparePrior.R b/evaluation/R/ComparePrior.R
new file mode 100755
index 0000000..e0163aa
--- /dev/null
+++ b/evaluation/R/ComparePrior.R
@@ -0,0 +1,33 @@
+library(ggplot2)
+
+prior1 <- read.table("/data/miraldiNB/Katko/Projects/Barski_CD4_Multiome/Outs/TRAC_loop/Prior/Prior_sum.tsv")
+prior2 <- read.table("/data/miraldiNB/Katko/Projects/Barski_CD4_Multiome/Outs/Seurat/Prior/MEMT_050723_FIMOp5_b.tsv")
+
+index <- intersect(rownames(prior1), rownames(prior2))
+prior1 <- prior1[index,]
+prior2 <- prior2[index,]
+
+# Count the number of positive-weight targets per TF in each prior
+prior_df <- matrix(0, nrow = ncol(prior1), ncol = 2,
+                   dimnames = list(colnames(prior1), c("TRAC", "Body")))
+
+for (i in seq_len(ncol(prior1))) {
+    prior_df[i, 1] <- sum(prior1[, i] > 0)
+    prior_df[i, 2] <- sum(prior2[, i] > 0)
+}
+
+prior_df <- as.data.frame(prior_df)
+pdf("Prior_Scatter.pdf", width = 8, height = 8)
+# print() is required so the plot renders when this script is source()'d
+print(ggplot(prior_df, aes(x = TRAC, y = Body)) + geom_point() + geom_abline(slope=1, intercept=0))
+dev.off()
+
+prior_df_hist <- data.frame(Freq = c(prior_df[,1], prior_df[,2]))
+prior_df_hist$TF <- c(rownames(prior_df), rownames(prior_df))
+prior_df_hist$Prior <- c(rep("Prior1", ncol(prior1)), rep("Prior2", ncol(prior2)))
+
+pdf("Prior_hist.pdf", width = 8, height = 8)
+print(ggplot(prior_df_hist, aes(x = Freq, color = Prior)) + geom_histogram(fill = "white", position = "identity", alpha = 0.7))
+dev.off()
\ No newline at end of file
diff --git a/evaluation/R/DeltaTFA.R b/evaluation/R/DeltaTFA.R
new file mode 100755
index 0000000..50ac71d
--- /dev/null
+++ b/evaluation/R/DeltaTFA.R
@@ -0,0 +1,376 @@
+library(biomaRt)
+library(stringr)
+library(DESeq2)
+library(rtracklayer) # provides import() for reading the GENCODE GTF below
+library(dplyr)       # provides %>% and distinct() used below
+
+counts <- read.table("GSE271788_dedup_counts.txt", header = TRUE)
+rownames(counts) <- counts[,1]
+counts <- counts[, 7:ncol(counts)] # drop annotation columns, keep sample counts
+
+new_names <- str_extract(colnames(counts), "Donor_\\d+_[A-Za-z0-9]+")
+new_names <- sapply(strsplit(new_names, "_"), function(x) paste(x[3], x[1], x[2], sep = "_"))
+colnames(counts) <- new_names
+
+ensembl_ids_clean <- sub("\\..*", "", rownames(counts))
+gtf <- import("gencode.v48.chr_patch_hapl_scaff.annotation.gtf.gz")
+# Keep only gene entries
+genes <- gtf[gtf$type == "gene"]
+
+# Build mapping table
+mapping <- data.frame(
+ ensembl_gene_id = gsub("\\..*", "", genes$gene_id), # remove version
+ gene_name = genes$gene_name
+) %>%
+ distinct()
+
+
+gene_ids <- gsub("\\..*", "", rownames(counts)) # remove versions
+symbol_map <- mapping$gene_name[match(gene_ids, mapping$ensembl_gene_id)]
+rownames(counts) <- ifelse(is.na(symbol_map), gene_ids, symbol_map)
+counts <- counts[!grepl("^ENSG", rownames(counts)), ]
+
+# Build metadata dataframe
+colnames(counts) <- make.unique(colnames(counts))
+
+clean_names <- sub("\\.\\d+$", "", colnames(counts))
+
+# Split by underscore
+split_info <- do.call(rbind, strsplit(clean_names, "_"))
+
+# Build metadata dataframe
+metadata <- data.frame(
+ TF = split_info[, 1],
+ Donor = paste(split_info[, 2], split_info[, 3], sep = "_"),
+ row.names = colnames(counts),
+ stringsAsFactors = FALSE
+)
+
+dds <- DESeqDataSetFromMatrix(
+ countData = counts,
+ colData = metadata,
+ design = ~ Donor + TF
+)
+dds <- DESeq(dds)
+
+#cor_matrix <- dcast(network, Gene ~ TF, value.var = "signedQuantile", fill = 0)
+
+
+### Plot TFA changes
+library(tidyverse)
+bulk_cor <- read.table("/data/miraldiNB/Katko/Projects/Julia/Inferelator_Julia/outputs/subNetworks/FullPseudobulk/FullPseudobulk/lambda0p5_200totSS_20tfsPerGene_subsamplePCT63/Combined/TFA_cor.txt")
+bulk_quant <- read.table("/data/miraldiNB/Katko/Projects/Julia/Inferelator_Julia/outputs/subNetworks/FullPseudobulk/FullPseudobulk/lambda0p5_200totSS_20tfsPerGene_subsamplePCT63/Combined/TFA_quant.txt")
+bulk_sign <- read.table("/data/miraldiNB/Katko/Projects/Julia/Inferelator_Julia/outputs/subNetworks/FullPseudobulk/FullPseudobulk/lambda0p5_200totSS_20tfsPerGene_subsamplePCT63/Combined/TFA_sign.txt")
+sc_cor <- read.table("/data/miraldiNB/Katko/Projects/Julia/Inferelator_Julia/outputs/subNetworks/scSubsampleFraction/lambda0p5_200totSS_20tfsPerGene_subsamplePCT10/Combined/TFA_cor.txt")
+sc_quant <- read.table("/data/miraldiNB/Katko/Projects/Julia/Inferelator_Julia/outputs/subNetworks/scSubsampleFraction/lambda0p5_200totSS_20tfsPerGene_subsamplePCT10/Combined/TFA_quant.txt")
+sc_sign <- read.table("/data/miraldiNB/Katko/Projects/Julia/Inferelator_Julia/outputs/subNetworks/scSubsampleFraction/lambda0p5_200totSS_20tfsPerGene_subsamplePCT10/Combined/TFA_binary.txt")
+tfa_matrix <- list(bulk_cor, bulk_quant, bulk_sign, sc_cor, sc_quant, sc_sign)
+names(tfa_matrix) <- c("bulk_cor","bulk_quant","bulk_sign","sc_cor","sc_quant","sc_sign")
+
+tf_list <- c(
+ "AIRE", "BACH2", "BCL11B", "BPTF", "CLOCK", "EPAS1", "ETS1", "FOXK1", "FOXP1", "FOXP3",
+ "GATA3", "GFI1", "HIVEP2", "IKZF1", "IRF1", "IRF2", "IRF4", "IRF7", "IRF9", "KLF2",
+ "KMT2A", "MBD2", "MYB", "NFE2L2", "NFAT5", "NFKB1", "NFKB2", "RELA",
+ "RELB", "REL", "RFX5", "RORC", "SETDB1", "SREBF1", "STAT1", "STAT2", "STAT3", "STAT5A",
+ "STAT5B", "TBX21", "TCF3", "TP53", "YBX1", "YBX3", "YY1", "ZBTB14", "ZFP3", "ZKSCAN1",
+ "ZNF329", "ZNF341", "ZNF791"
+)
+tf_list <- tf_list[which(tf_list %in% rownames(tfa_matrix[[1]]))]
+library(dplyr)
+library(tidyr)
+library(ggplot2)
+library(purrr)
+
+
+analyse_one_matrix <- function(tfa_matrix,
+ tf_list,
+ pdf_name,
+ sig_cutoffs = c(`***` = 0.001,
+ `**` = 0.01,
+ `*` = 0.05)) {
+
+ p_to_symbol <- function(p) {
+ stars <- names(sig_cutoffs)[which(p < sig_cutoffs)]
+ if (length(stars) == 0) "ns" else stars[1]
+ }
+
+ res <- purrr::map_dfr(tf_list, function(tf) {
+
+ #--- pull the row safely ---------------------------------------------------
+ if (!tf %in% rownames(tfa_matrix)) {
+ warning("TF ", tf, " not present in matrix; skipping.")
+ return(tibble::tibble(TF = tf, p.value = NA_real_,
+ signif = "NA", direction = "NA", plot = list(NULL)))
+ }
+
+ tf_activity <- tfa_matrix[tf, ] # numeric vector
+ tf_df <- tibble::tibble(Sample = names(tf_activity),
+ Activity = as.numeric(tf_activity)) %>%
+ dplyr::mutate(
+ Knockout_TF = sub("_Donor_.*", "", Sample),
+ Donor = sub(".*_Donor_", "", Sample) %>% sub("\\..*", "", .)
+ )
+
+ #--- average per donor -----------------------------------------------------
+ controls <- tf_df %>% dplyr::filter(Knockout_TF == "AAVS1") %>%
+ dplyr::group_by(Donor) %>%
+ dplyr::summarise(Control = mean(Activity), .groups = "drop")
+
+ knockouts <- tf_df %>% dplyr::filter(Knockout_TF == tf) %>%
+ dplyr::group_by(Donor) %>%
+ dplyr::summarise(Knockout = mean(Activity), .groups = "drop")
+
+ paired <- dplyr::inner_join(controls, knockouts, by = "Donor")
+
+ if (nrow(paired) < 2) {
+ warning("TF ", tf, ": fewer than 2 paired donors – skipped")
+ return(tibble::tibble(TF = tf, p.value = NA_real_,
+ signif = "NA", direction = "NA", plot = list(NULL)))
+ }
+
+ #--- stats -----------------------------------------------------------------
+ t_res <- stats::t.test(paired$Knockout, paired$Control, paired = TRUE)
+ p_val <- t_res$p.value
+ direction <- ifelse(mean(paired$Knockout - paired$Control) > 0, "Up", "Down")
+ signif_sym <- p_to_symbol(p_val)
+
+ #--- plot ------------------------------------------------------------------
+ long <- paired %>%
+ tidyr::pivot_longer(Control:Knockout,
+ names_to = "Condition",
+ values_to = "Activity")
+
+ g <- ggplot2::ggplot(long, ggplot2::aes(Condition, Activity, group = Donor)) +
+ ggplot2::geom_line(ggplot2::aes(color = Donor), linewidth = 1.1, alpha = 0.8) +
+ ggplot2::geom_point(ggplot2::aes(color = Donor), size = 3) +
+ ggplot2::theme_minimal(base_size = 14) +
+ ggplot2::labs(
+ title = paste(tf, "Activity Before and After Knockout"),
+ subtitle = paste0("p = ", signif(p_val, 3),
+ " (", signif_sym, ", ", direction, ")"),
+ x = "Condition", y = "Estimated TF Activity", color = "Donor") +
+ ggplot2::theme(plot.title = ggplot2::element_text(hjust = 0.5, face = "bold"),
+ plot.subtitle = ggplot2::element_text(hjust = 0.5))
+
+ tibble::tibble(TF = tf,
+ p.value = p_val,
+ signif = signif_sym,
+ direction = direction,
+ plot = list(g))
+ })
+
+ #--- write one PDF with all TFs ---------------------------------------------
+ grDevices::pdf(pdf_name, width = 8, height = 10)
+ purrr::walk(res$plot[!purrr::map_lgl(res$plot, is.null)], print)
+ grDevices::dev.off()
+
+ res # keep the plot column this time
+}
+
+################################################################
+## 2. MANY-MATRIX WRAPPER (robust, no imap_dfr required) ##
+################################################################
+analyse_many_matrices <- function(tfa_list, # named list or named file vector
+ tf_list,
+ out_dir = ".") {
+
+ # read .rds files if character vector of paths was supplied
+ if (is.character(tfa_list) && !is.matrix(tfa_list[[1]])) {
+ tfa_list <- lapply(tfa_list, readRDS)
+ names(tfa_list) <- basename(names(tfa_list)) # keep vector names
+ }
+
+ if (is.null(names(tfa_list)) || any(names(tfa_list) == ""))
+ stop("tfa_list must be a *named* list or vector so datasets can be labelled.")
+
+ dir.create(out_dir, showWarnings = FALSE, recursive = TRUE)
+
+ results <- vector("list", length(tfa_list))
+ out_i <- 1L
+
+ for (mat_name in names(tfa_list)) {
+
+ mat <- tfa_list[[mat_name]]
+ if (!(is.matrix(mat) || is.data.frame(mat))) {
+ warning("Skipping '", mat_name, "': not a matrix/data.frame")
+ next
+ }
+
+ pdf_file <- file.path(out_dir,
+ paste0("TF_activity_changes_", mat_name, ".pdf"))
+
+ message("Processing ", mat_name, " …")
+ stats_one <- tryCatch(
+ analyse_one_matrix(mat, tf_list, pdf_name = pdf_file),
+ error = function(e) {
+ warning("Failed on ", mat_name, ": ", conditionMessage(e))
+ NULL
+ }
+ )
+
+ if (is.null(stats_one)) next # skip failures
+
+ stats_one <- tibble::as_tibble(stats_one) # force tibble
+ stats_one <- dplyr::mutate(stats_one, dataset = mat_name, .before = 1)
+
+ results[[out_i]] <- stats_one
+ out_i <- out_i + 1L
+ }
+
+ dplyr::bind_rows(results[seq_len(out_i - 1L)])
+}
+
+
+all_stats <- analyse_many_matrices(tfa_matrix,
+ tf_list,
+ out_dir = "plots") # PDFs in ./plots
+
+## View or export the combined statistics:
+print(all_stats)
+library(ComplexHeatmap)
+library(circlize)
+library(grid)
+library(dplyr)
+library(tidyr)
+
+pct_expressing <- function(obj,
+ genes,
+ assay = "RNA",
+ slot = "data") {
+
+ expr <- Seurat::GetAssayData(obj[[assay]], slot = slot) # genes × cells
+ vec <- vapply(genes, function(g) {
+ if (!g %in% rownames(expr)) return(NA_real_)
+ mean(expr[g, ] > 0) * 100 # % of cells
+ }, numeric(1))
+ names(vec) <- genes
+ vec
+}
+
+# `obj` is the Seurat object holding the single-cell data, loaded upstream
+Seurat::DefaultAssay(obj) <- "RNA" # if not already set
+pct_vec <- pct_expressing(obj, tf_list)
+
+score_mat <- all_stats %>%
+ mutate(score = ifelse(
+ is.na(p.value),
+ NA_real_,
+ -log10(p.value) * ifelse(direction == "Up", 1, -1))
+ ) %>%
+ select(TF, dataset, score) %>%
+ pivot_wider(names_from = dataset, values_from = score) %>%
+ as.data.frame()
+
+rownames(score_mat) <- score_mat$TF
+score_mat <- as.matrix(score_mat[, -1, drop = FALSE])
+
+star_mat <- all_stats %>%
+ mutate(stars = case_when(
+ is.na(p.value) ~ "",
+ p.value < 0.001 ~ "***",
+ p.value < 0.01 ~ "**",
+ p.value < 0.10 ~ "*",
+ TRUE ~ ""
+ )) %>%
+ select(TF, dataset, stars) %>%
+ pivot_wider(names_from = dataset, values_from = stars) %>%
+ as.data.frame()
+
+rownames(star_mat) <- star_mat$TF
+star_mat <- as.matrix(star_mat[, -1, drop = FALSE])
+
+# `target_number` (named vector of DE-gene counts per TF) must be defined
+# upstream, e.g. from the DESeq2 results above
+deg_counts <- setNames(rep(0, nrow(score_mat)), rownames(score_mat))
+deg_counts[names(target_number)] <- target_number # overwrite where present
+
+# ---- LEFT annotation: % of cells expressing ----------------------------------
+row_anno_left <- rowAnnotation(
+ `% Cells\nExpressing` = anno_barplot(
+ pct_vec[rownames(score_mat)], # ensure correct order
+ gp = gpar(fill = "steelblue", col = NA),
+ bar_width = 0.85,
+ border = FALSE,
+ axis_param = list(at = c(0, 50, 100), gp = gpar(fontsize = 7))
+ ),
+ annotation_name_side = "top",
+ annotation_name_gp = gpar(fontsize = 9, fontface = "bold"),
+ width = unit(2.4, "cm")
+)
+
+# ---- RIGHT annotation: # of DE genes -----------------------------------------
+row_anno_right <- rowAnnotation(
+ `# DE genes` = anno_barplot(
+ deg_counts,
+ gp = gpar(fill = "grey40", col = NA),
+ bar_width = 0.85,
+ border = FALSE,
+ axis_param = list(at = c(0, max(deg_counts)), gp = gpar(fontsize = 7))
+ ),
+ annotation_name_side = "top",
+ annotation_name_gp = gpar(fontsize = 9, fontface = "bold"),
+ width = unit(2.3, "cm")
+)
+
+max_abs <- max(abs(score_mat), na.rm = TRUE)
+col_fun <- circlize::colorRamp2(c(-max_abs, 0, max_abs),
+ c("navy", "white", "firebrick"))
+text_col <- function(fill) {
+ rgb <- col2rgb(fill) / 255
+ ifelse(0.299*rgb[1] + 0.587*rgb[2] + 0.114*rgb[3] < 0.5, "white", "black")
+}
+
+row_labels <- rowAnnotation(
+ TF = anno_text(
+ rownames(score_mat),
+ gp = gpar(fontsize = 9),
+ just = "left",
+ location = 0.5 # centred vertically in each cell
+ ),
+ width = unit(2.2, "cm"), # reserve space for the longest name
+ show_annotation_name = FALSE
+)
+
+###############################################################################
+## 2. Core heat-map (row names turned OFF) ##################################
+###############################################################################
+ht_core <- Heatmap(
+ score_mat,
+ name = "-log10(p)",
+ col = col_fun,
+ na_col = "grey90",
+ border = TRUE,
+ row_km = 3,
+
+ show_row_names = FALSE, # <── row names handled by row_labels
+ column_names_gp = gpar(fontsize = 9),
+
+ heatmap_legend_param = list(
+ title_gp = gpar(fontsize = 10, fontface = "bold"),
+ labels_gp = gpar(fontsize = 9)
+ ),
+ cell_fun = function(j, i, x, y, width, height, fill) {
+ lab <- star_mat[i, j]
+ if (nzchar(lab)) {
+ grid.text(
+ lab, x, y,
+ gp = gpar(fontsize = 8,
+ fontface = "bold",
+ col = text_col(fill))
+ )
+ }
+ }
+)
+
+###############################################################################
+## 3. Assemble LEFT + CORE + LABELS + RIGHT ##########################
+###############################################################################
+ht_full <- row_anno_left + ht_core + row_labels + row_anno_right
+# %-expressing matrix TF names #DE genes
+
+###############################################################################
+## 4. Draw / save ###########################################################
+###############################################################################
+pdf("/data/miraldiNB/Katko/Projects/Julia/Inferelator_Julia/outputs/subNetworks/scSubsampleFraction/lambda0p5_200totSS_20tfsPerGene_subsamplePCT10/Combined/plots/TF_heatmap_with_two_bars_and_labels.pdf", width = 7.5, height = 9)
+draw(ht_full,
+ heatmap_legend_side = "right",
+ annotation_legend_side = "right")
+dev.off()
\ No newline at end of file
diff --git a/evaluation/R/GRN_Tfh10_PR_heatmap.R b/evaluation/R/GRN_Tfh10_PR_heatmap.R
new file mode 100755
index 0000000..83f420e
--- /dev/null
+++ b/evaluation/R/GRN_Tfh10_PR_heatmap.R
@@ -0,0 +1,444 @@
+# GRN_Tfh10_PR_heatmap_final2.R
+# Heatmap of PR results
+rm(list=ls())
+options(stringsAsFactors=FALSE)
+set.seed(42)
+suppressPackageStartupMessages({
+library(ggplot2)
+library(scales)
+})
+source('/data/miraldiNB/Michael/Scripts/GRN/utils_grn_analysis.R')
+source('/data/miraldiNB/wayman/scripts/net_utils.R')
+source('/data/miraldiNB/wayman/scripts/clustering_utils.R')
+
+####### INPUTS #######
+
+# Output directory
+dir_out <- '/data/miraldiNB/Michael/mCD4T_Wayman/Figures/Tfh10_S4B/2p5PCT'
+dir.create(dir_out, showWarnings=FALSE, recursive=TRUE)
+
+# File save name
+file_save <- 'Tfh10_S4'
+
+# Gene target file
+file_targ <- "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/targRegLists/targetGenes_names.txt"
+
+# TF list
+file_tf <- "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/targRegLists/potRegs_names.txt"
+
+# list of networks (in order)
+file_grn <- c(
+
+# No Prior : lambdaBias=1 + tfaOpt='_TFmRNA' + ATAC_KOprior
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda1p0_80totSS_20tfsPerGene_subsamplePCT63/TFmRNA/edges_subset.txt", # NoPrior (TFmRNA + ATAC)
+
+# ATAC Prior
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63/TFmRNA/edges_subset.txt", # +mRNA (ATAC)
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda0p25_80totSS_20tfsPerGene_subsamplePCT63/TFmRNA/edges_subset.txt", # ++mRNA (ATAC)
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda1p0_80totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.txt", # TFA (ATAC)
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.txt", # +TFA (ATAC)
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda0p25_80totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.txt", # ++TFA (ATAC)
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63/Combined/combined.tsv", # +Combine (ATAC)
+
+
+# ATAC Prior (sc-Inferelator)
+#
+
+# ATAC+ChIP Prior
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63/TFmRNA/edges_subset.txt", # +mRNA (ATAC + ChIP)
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/Bulk/lambda0p25_80totSS_20tfsPerGene_subsamplePCT63/TFmRNA/edges_subset.txt", # ++mRNA (ATAC + ChIP)
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/Bulk/lambda1p0_80totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.txt", # TFA
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.txt", # +TFA (ATAC + ChIP)
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/Bulk/lambda0p25_80totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.txt", # ++TFA (ATAC + ChIP)
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63/Combined/combined.tsv", # +Combine (ATAC + ChIP)
+
+# ATAC + ChIP Prior (sc- Inferelator)
+#
+
+# ATAC+KO Prior
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63/TFmRNA/edges_subset.txt", # +mRNA (ATAC + KO)
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/Bulk/lambda0p25_80totSS_20tfsPerGene_subsamplePCT63/TFmRNA/edges_subset.txt", # ++mRNA (ATAC + KO)
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/Bulk/lambda1p0_80totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.txt", # TFA
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.txt", # +TFA (ATAC + KO)
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/Bulk/lambda0p25_80totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.txt", # ++TFA (ATAC + KO)
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63/Combined/combined.tsv", # +Combine (ATAC + KO)
+
+# ATAC + KO Prior (sc- Inferelator)
+#
+
+#Cell Oracle
+# "/data/miraldiNB/Michael/mCD4T_Wayman/CellOracle/GRNs/Unprocessed/Combined/Old/max_pCutoff_0.1pct_combined_GRN_3cols.tsv",
+"/data/miraldiNB/Michael/mCD4T_Wayman/CellOracle/GRNs/Unprocessed/CombinedGRNs/max/max_pCut0.1_meanEdgesPerGene20_combined_GRN.tsv",
+
+# Scenic Plus
+"/data/miraldiNB/Michael/mCD4T_Wayman/scenicPlus/eRegulons_direct_filtered.tsv", # Direct
+"/data/miraldiNB/Michael/mCD4T_Wayman/scenicPlus/eRegulons_extended_filtered.tsv", # Extended
+
+#SupirFactor
+"/data/miraldiNB/Michael/mCD4T_Wayman/SupirFactor/Hierarchical_20250212_2043/GRN_Hierarchical_long.tsv", # Hierarchical
+"/data/miraldiNB/Michael/mCD4T_Wayman/SupirFactor/Shallow_20250330_0212/GRN_Shallow_long.tsv" # Shallow
+)
+
+# network names
+name_grn <- c(
+
+'No Prior',
+
+'ATAC mRNA +',
+'ATAC mRNA ++',
+'ATAC TFA',
+'ATAC TFA +',
+'ATAC TFA ++',
+'ATAC Combined +',
+
+# 'sc ATAC mRNA +',
+# 'sc ATAC TFA +',
+# 'sc ATAC Combined +',
+
+'ChIP mRNA +',
+'ChIP mRNA ++',
+'ChIP TFA',
+'ChIP TFA +',
+'ChIP TFA ++',
+'ChIP Combined +',
+
+# 'sc ChIP mRNA +',
+# 'sc ChIP TFA +',
+# 'sc ChIP Combined +',
+
+'KO mRNA +',
+'KO mRNA ++',
+'KO TFA',
+'KO TFA +',
+'KO TFA ++',
+'KO Combined +',
+
+# 'sc KO mRNA +',
+# 'sc KO TFA +',
+# 'sc KO Combined +',
+
+'Cell Oracle',
+# 'Cell Oracle Mean',
+
+'SCENIC+ Dir',
+'SCENIC+ Ext',
+
+'SupirFactor H',
+'SupirFactor Sh'
+
+)
+
+# network type
+type_grn <- c(
+
+'No Prior',
+
+'ATAC',
+'ATAC',
+'ATAC',
+'ATAC',
+'ATAC',
+'ATAC',
+
+# 'ATAC',
+# 'ATAC',
+# 'ATAC',
+
+'ChIP',
+'ChIP',
+'ChIP',
+'ChIP',
+'ChIP',
+'ChIP',
+
+# 'ChIP',
+# 'ChIP',
+# 'ChIP',
+
+'KO',
+'KO',
+'KO',
+'KO',
+'KO',
+'KO',
+
+# 'KO',
+# 'KO',
+# 'KO',
+
+'Cell Oracle',
+# 'Cell Oracle',
+
+'SCENIC+',
+'SCENIC+',
+
+'SupirFactor',
+'SupirFactor'
+)
+
+# Gold standard names
+name_gs <- c(
+'ChIP',
+'KO',
+'KC'
+)
+
+# List of gold standards
+file_gs <- c(
+ # '/data/miraldiNB/wayman/projects/Tfh10/outs/202204/GRN/GS/TF_binding/priors/FDR5_Rank50/prior_ChIP_Thelper_Miraldi2019Th17_combine_FDR5_Rank50_sp.tsv',
+ # '/data/miraldiNB/wayman/projects/Tfh10/outs/202204/GRN/GS/RNA/priors/prior_RNA_Thelper_Miraldi2019Th17_combine_Log2FC0p5_FDR20_Rank50_sp.tsv',
+ # '/data/miraldiNB/wayman/projects/Tfh10/outs/202204/GRN/GS/KC/prior_KC_Thelper_Miraldi2019Th17_Rank100_sp.tsv'
+ "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/goldStandards/prior_ChIP_Thelper_Miraldi2019Th17_combine_FDR5_Rank50_sp.tsv",
+ "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/goldStandards/prior_TF_KO_RNA_Thelper_Miraldi2019Th17_combine_Log2FC0p5_FDR20_Rank50_sp.tsv",
+ "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/goldStandards/prior_KC_Thelper_Miraldi2019Th17_Rank100_sp.tsv"
+)
+
+# gold standard save file name
+save_gs <- ''
+
+# params
+min_targ_gs <- 20 # gold standard min target cutoff
+cutoff_rec <- 0.025 # recall cutoff
+range_fc <- c(-2,2) # color scale range for log2fc heatmap
+breaks_fc <- c(-2,-1,0,1,2)
+heat_h <- 8 # heatmap height
+# heat_w <- 15 # heatmap width
+heat_w <- 21 # heatmap width
+box_w <- 6 # boxplot width
+
+######################
+
+# Output file label
+label_pcut <- 100*cutoff_rec
+file_save <- if (nzchar(save_gs)) paste0(file_save, '_', save_gs) else file_save
+
+
+# Load tf list
+tf <- readLines(file_tf)
+
+# Calc precision-recall
+df_pr <- NULL
+for (ix in 1:length(file_grn)){
+
+ print(paste0('Process GRN: ',ix,'/',length(file_grn)))
+ for (jx in 1:length(file_gs)){
+
+    # Filter gold standard TFs
+ gs <- read.delim(file_gs[jx], header=TRUE, sep='\t')
+ gs <- subset(gs, TF %in% tf)
+ keep_tf <- names(which(table(gs$TF) >= min_targ_gs))
+ gs <- subset(gs, TF %in% keep_tf)
+
+ # All PR
+ curr_tf <- sort(unique(gs$TF))
+ pr <- grn_pr(grn=file_grn[ix], gs=gs, gene_tar=file_targ, tf=curr_tf)
+ prec <- pr$P; prec <- c(prec, pr$Rand)
+ recall <- pr$R; recall <- c(recall, 1)
+
+ # Precision at cutoff
+ slope_pr <- (prec[which(recall >= cutoff_rec)[1]] - prec[which(recall >= cutoff_rec)[1]-1])/
+ (recall[which(recall >= cutoff_rec)[1]] - recall[which(recall >= cutoff_rec)[1]-1])
+ pcut <- prec[which(recall >= cutoff_rec)[1]-1] +
+ slope_pr*(cutoff_rec-recall[which(recall >= cutoff_rec)[1]-1])
+ curr_df <- data.frame(GRN=name_grn[ix], Type=type_grn[ix], GS=name_gs[jx],
+ TF='All TF', Pcut=pcut, Rand=pr$Rand)
+ df_pr <- rbind(df_pr, curr_df)
+
+ for (kx in curr_tf){
+
+ # pr <- grn_pr(grn=file_grn[ix], gs=file_gs[jx], gene_tar=file_targ, tf=kx)
+ pr <- grn_pr(grn=file_grn[ix], gs=gs, gene_tar=file_targ, tf=kx)
+ if (!is.null(pr)){
+ prec <- pr$P; prec <- c(prec, pr$Rand)
+ recall <- pr$R; recall <- c(recall, 1)
+
+ # precision at cutoff
+ slope_pr <- (prec[which(recall >= cutoff_rec)[1]] - prec[which(recall >= cutoff_rec)[1]-1])/
+ (recall[which(recall >= cutoff_rec)[1]] - recall[which(recall >= cutoff_rec)[1]-1])
+ pcut <- prec[which(recall >= cutoff_rec)[1]-1] +
+ slope_pr*(cutoff_rec-recall[which(recall >= cutoff_rec)[1]-1])
+ curr_df <- data.frame(GRN=name_grn[ix], Type=type_grn[ix], GS=name_gs[jx],
+ TF=kx, Pcut=pcut, Rand=pr$Rand)
+ df_pr <- rbind(df_pr, curr_df)
+ }
+ }
+ }
+}
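The precision-at-cutoff computed inside the loop above is a linear interpolation between the two PR points that bracket the recall cutoff. A standalone sketch with made-up PR values (all numbers hypothetical, not from any real network):

```r
# Linearly interpolate precision at a fixed recall cutoff (toy PR curve).
recall <- c(0.00, 0.01, 0.02, 0.03, 0.05)
prec   <- c(1.00, 0.80, 0.60, 0.50, 0.40)
cutoff_rec <- 0.025

i     <- which(recall >= cutoff_rec)[1]                    # first point at or past the cutoff
slope <- (prec[i] - prec[i-1]) / (recall[i] - recall[i-1]) # rise/run between bracketing points
pcut  <- prec[i-1] + slope * (cutoff_rec - recall[i-1])
pcut  # 0.55
```

Note the `[1]-1` indexing in the script assumes the cutoff does not fall before the first PR point; if `which(recall >= cutoff_rec)[1]` were 1, the interpolation would index position 0.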
+df_pr$PcutLog2FC <- log2(df_pr$Pcut/df_pr$Rand)
+df_pr$GS_TF <- paste(df_pr$GS, df_pr$TF, sep='_')
+
+# order gold standards
+order_gs <- NULL
+for (ix in name_gs){
+ order_gs <- c(order_gs, paste(ix,'All TF',sep='_'))
+ curr_pr <- subset(df_pr, GS==ix & TF != 'All TF')
+ mat <- net_sparse_2_full(curr_pr[,c('GS_TF','GRN','PcutLog2FC')])
+ curr_order <- colnames(mat)[cluster_hierarchical(mat)]
+ order_gs <- c(order_gs, curr_order)
+}
+df_pr$GS_TF <- factor(df_pr$GS_TF, levels=order_gs)
+
+# order networks
+df_pr$GRN <- factor(df_pr$GRN, levels=rev(name_grn))
+
+
+# Axis label
+label_gs <- order_gs
+for(ix in name_gs){
+ label_gs <- gsub(paste0(ix,'_'),'',label_gs)
+}
+face_gs <- rep('plain',length(label_gs))
+face_gs[which(label_gs=='All TF')] <- 'bold'
+
+tf_ko <- unique(subset(df_pr, GS=='KO' & TF != 'All TF')$TF)
+tf_chip <- unique(subset(df_pr, GS=='ChIP' & TF != 'All TF')$TF)
+tf_int <- intersect(tf_ko, tf_chip)
+for (ix in tf_int){
+ face_gs[which(label_gs==ix)[1]] <- 'italic'
+}
+
+# Lines separating gold standards
+# vline_sep <- ceiling(cumsum(table(df_pr$GS)[name_gs]/length(name_grn)))+0.5
+tf_counts <- tapply(df_pr$TF, df_pr$GS, function(x) length(unique(x)))
+vline_sep <- cumsum(tf_counts[name_gs]) + 0.5
+vline_sep <- vline_sep[1:(length(vline_sep)-1)]
+
+# lines separating network types
+order_type <- rev(unique(type_grn))
+hline_sep <- cumsum(table(type_grn)[order_type])+0.5
+hline_sep <- hline_sep[1:(length(hline_sep)-1)]
+
+color_pal <- c('dodgerblue3', 'white', 'firebrick3') # low, mid, high; names are ignored by scale_fill_gradientn() # c('#2c7bb6','#abd9e9','white','#d7191c','#b10026')
+# Heatmap precision at cutoff
+print('Heatmap precision at cutoff')
+ggplot(df_pr, aes(x=GS_TF, y=GRN, fill=PcutLog2FC)) +
+ geom_tile() +
+ # geom_tile(data = subset(df_pr, !is.na(PcutLog2FC))) + # Removes NA tiles
+ # geom_tile(data = subset(df_pr, is.na(PcutLog2FC)), fill = "lightgrey") + # Adds NA tiles without borders
+ geom_vline(xintercept=vline_sep, lwd=0.8, colour='gray30') +
+ geom_hline(yintercept=hline_sep, lwd=0.8, colour='gray30') +
+ labs(x=' ', y=' ', fill='Log2 FC') +
+ theme(axis.text=element_text(size=12,face='plain')) +
+ theme(axis.title=element_text(size=14,face='plain')) +
+ theme(axis.text.x=element_text(angle=45, hjust=1)) +
+ theme(axis.text.x=element_text(face=face_gs)) +
+ # scale_fill_gradient2(low='dodgerblue3', high='firebrick3', limits=range_fc, na.value="lightgrey",oob=squish) +
+ scale_fill_gradientn(
+ # colors=c(low='dodgerblue3', high='firebrick3'),
+ limits=range_fc,
+ colors= color_pal,
+ breaks=breaks_fc,
+ na.value="lightgrey",oob=squish
+ )+
+ scale_x_discrete(labels=label_gs)
+file_out <- file.path(dir_out, paste0('heat_',file_save,'_PRcut',label_pcut,'.pdf'))
+ggsave(file_out, height=heat_h, width=heat_w)
+
+ggplot(df_pr, aes(x=GS_TF, y=GRN, fill=PcutLog2FC)) +
+ geom_tile() +
+ # geom_tile(data = subset(df_pr, !is.na(PcutLog2FC))) + # Removes NA tiles
+ # geom_tile(data = subset(df_pr, is.na(PcutLog2FC)), fill = "lightgrey") + # Adds NA tiles without borders
+ geom_vline(xintercept=vline_sep, lwd=0.8, colour='gray30') +
+ geom_hline(yintercept=hline_sep, lwd=0.8, colour='gray30') +
+ labs(x=' ', y=' ', fill=' ') +
+ theme(axis.text=element_text(size=12,face='plain')) +
+ theme(axis.title=element_text(size=14,face='plain')) +
+ theme(axis.text.x=element_text(angle=45, hjust=1)) +
+ theme(axis.text.x=element_text(face=face_gs)) +
+ # scale_fill_gradient2(low='dodgerblue3', high='firebrick3', limits=range_fc, na.value="lightgrey",oob=squish) +
+ scale_fill_gradientn(
+ # colors=c(low='dodgerblue3', high='firebrick3'),
+ limits=range_fc,
+ colors= color_pal,
+ breaks=breaks_fc,
+ na.value="lightgrey",oob=squish
+ )+
+ scale_x_discrete(labels=label_gs) +
+ theme(legend.position='bottom', legend.text=element_text(size=16), legend.key.width=unit(1.5,'cm'))
+file_out <- file.path(dir_out, paste0('heat_',file_save,'_PRcut',label_pcut,'_2.pdf'))
+ggsave(file_out, height=heat_h, width=heat_w)
+
+# Boxplot log2fc
+print('boxplot log2fc')
+for (ix in name_gs){
+ df_pr_sub <- subset(df_pr, GS==ix & TF != 'All TF')
+ # Inside a loop, ggplot objects are not auto-printed, so pass the plot to
+ # ggsave() explicitly rather than relying on last_plot()
+ p_box <- ggplot(df_pr_sub, aes(x=GRN, y=PcutLog2FC)) + geom_boxplot() + coord_flip() +
+ theme_minimal() +
+ labs(x=' ', y='Log2 Fold-Change') +
+ theme(axis.text=element_text(size=12,face='plain')) +
+ theme(axis.title=element_text(size=14,face='plain')) +
+ theme(axis.line=element_line(size=0.6, colour='grey30', linetype='solid'))
+ file_out <- file.path(dir_out, paste0('box_',file_save,'_PRcut',label_pcut,'_',ix,'.pdf'))
+ ggsave(file_out, plot=p_box, height=heat_h, width=box_w)
+}
+
+
+
+# set regions where GS is tested on itself to NA
+df_pr_grey <- df_pr
+df_pr_grey$PcutLog2FC <- ifelse((df_pr$Type == "ChIP" & df_pr$GS != "KO") |
+ (df_pr$Type == "KO" & df_pr$GS != "ChIP"),
+ NA, df_pr$PcutLog2FC)
+
+# Heatmap precision at cutoff
+print('Heatmap precision at cutoff - greyed out GS tested regions')
+ggplot(df_pr_grey, aes(x=GS_TF, y=GRN, fill=PcutLog2FC)) +
+ geom_tile() +
+ # geom_tile(data = subset(df_pr, !is.na(PcutLog2FC))) + # Removes NA tiles
+ # geom_tile(data = subset(df_pr, is.na(PcutLog2FC)), fill = "lightgrey") + # Adds NA tiles without borders
+ geom_vline(xintercept=vline_sep, lwd=0.8, colour='gray30') +
+ geom_hline(yintercept=hline_sep, lwd=0.8, colour='gray30') +
+ labs(x=' ', y=' ', fill='Log2 FC') +
+ theme(axis.text=element_text(size=12,face='plain')) +
+ theme(axis.title=element_text(size=14,face='plain')) +
+ theme(axis.text.x=element_text(angle=45, hjust=1)) +
+ theme(axis.text.x=element_text(face=face_gs)) +
+ # scale_fill_gradient2(low='dodgerblue3', high='firebrick3', limits=range_fc, na.value="lightgrey",oob=squish) +
+ scale_fill_gradientn(
+ limits=range_fc,
+ colors= color_pal,
+ breaks=breaks_fc,
+ na.value="lightgrey",oob=squish
+ )+
+ scale_x_discrete(labels=label_gs)
+file_out <- file.path(dir_out, paste0('heat_',file_save,'_PRcut',label_pcut,'_greyed_out.pdf'))
+ggsave(file_out, height=heat_h, width=heat_w)
+
+ggplot(df_pr_grey, aes(x=GS_TF, y=GRN, fill=PcutLog2FC)) +
+ geom_tile() +
+ # geom_tile(data = subset(df_pr, !is.na(PcutLog2FC))) + # Removes NA tiles
+ # geom_tile(data = subset(df_pr, is.na(PcutLog2FC)), fill = "lightgrey") + # Adds NA tiles without borders
+ geom_vline(xintercept=vline_sep, lwd=0.8, colour='gray30') +
+ geom_hline(yintercept=hline_sep, lwd=0.8, colour='gray30') +
+ labs(x=' ', y=' ', fill=' ') +
+ theme(axis.text=element_text(size=12,face='plain')) +
+ theme(axis.title=element_text(size=14,face='plain')) +
+ theme(axis.text.x=element_text(angle=45, hjust=1)) +
+ theme(axis.text.x=element_text(face=face_gs)) +
+ # scale_fill_gradient2(low='dodgerblue3', high='firebrick3', limits=range_fc, na.value="lightgrey",oob=squish) +
+ scale_fill_gradientn(
+ # colors=c(low='dodgerblue3', high='firebrick3'),
+ limits=range_fc,
+ colors= color_pal,
+ breaks=breaks_fc,
+ na.value="lightgrey",oob=squish
+ )+
+ scale_x_discrete(labels=label_gs) +
+ theme(legend.position='bottom', legend.text=element_text(size=16), legend.key.width=unit(1.5,'cm'))
+file_out <- file.path(dir_out, paste0('heat_',file_save,'_PRcut',label_pcut,'_greyed_out_2.pdf'))
+ggsave(file_out, height=heat_h, width=heat_w)
+
+# Boxplot log2fc
+print('boxplot log2fc')
+for (ix in name_gs){
+ df_pr_sub <- subset(df_pr_grey, GS==ix & TF != 'All TF')
+ # Inside a loop, ggplot objects are not auto-printed, so pass the plot to
+ # ggsave() explicitly rather than relying on last_plot()
+ p_box <- ggplot(df_pr_sub, aes(x=GRN, y=PcutLog2FC)) + geom_boxplot() + coord_flip() +
+ theme_minimal() +
+ labs(x=' ', y='Log2 Fold-Change') +
+ theme(axis.text=element_text(size=12,face='plain')) +
+ theme(axis.title=element_text(size=14,face='plain')) +
+ theme(axis.line=element_line(size=0.6, colour='grey30', linetype='solid'))
+ file_out <- file.path(dir_out, paste0('box_',file_save,'_PRcut',label_pcut,'_',ix,'_greyed_out.pdf'))
+ ggsave(file_out, plot=p_box, height=heat_h, width=box_w)
+}
+print('DONE')
diff --git a/evaluation/R/TFA_Viz.R b/evaluation/R/TFA_Viz.R
new file mode 100755
index 0000000..8a89c1b
--- /dev/null
+++ b/evaluation/R/TFA_Viz.R
@@ -0,0 +1,245 @@
+rm(list = ls())
+options(stringsAsFactors=FALSE)
+suppressPackageStartupMessages({
+ library(ComplexHeatmap)
+ library(dplyr)
+ library(tidyverse)
+ library(circlize)
+ library(RColorBrewer)
+ library(gridtext)
+ library(readxl)
+ library(reshape2)
+ library(Seurat)
+ library(ggplot2)
+})
+
+source("/data/miraldiNB/Michael/Scripts/standardize_normalize_pseudobulk.R")
+source("/data/miraldiNB/Michael/Scripts/GSEA/GSEA_utils.R")
+
+dirOut <- "/data/miraldiNB/Michael/mCD4T_Wayman/Figures/Fig4/Test"
+dir.create(dirOut, showWarnings = F, recursive = T)
+
+
+# network <- "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.tsv"
+tfa_file <- "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/Combined/TFA.txt"
+file_k_clust_across <- NULL #file.path(dir_out, 'k_clust_across.rds')
+file_k_clust_within <- NULL #file.path(dir_out, 'k_clust_within.rds')
+
+order_celltype <- c('Tfh10','Tfh_Int','Tfh',
+ 'Tfr','cTreg','eTreg','rTreg','Treg_Rorc',
+ 'Th17','Th1','CTL_Prdm1','CTL_Bcl6',
+ #'TM_Act','TM_ISG',
+ 'TEM','TCM')
+order_age <- c('Young','Old')
+order_rep <- c('R1','R2','R3','R4')
+
+tfa <- read.table(tfa_file)
+meta_data <- NULL
+for (ix in order_celltype){
+ for (jx in order_age){
+ curr_sample <- paste(order_rep, jx, ix, sep='_')
+ curr_meta <- data.frame(row.names=curr_sample, CellType=ix, Age=jx, Rep=order_rep)
+ meta_data <- rbind(meta_data, curr_meta)
+ }
+}
+
+# diff <- setdiff(rownames(meta_data), colnames(tfa))
+meta_data <- meta_data[intersect(colnames(tfa), rownames(meta_data)), ]
+
+# rearrange tfa matrix
+tfa <- tfa[ ,rownames(meta_data)]
+# Z-scoring or Standardization
+z_tfa <- standardizeAndNormalizeCounts(tfa, meta_data = meta_data, celltype_var = "CellType", epsilon = NULL)
+tfa_z_across = z_tfa$z_across
+tfa_z_within = z_tfa$z_within
+
+# Create Annotation Color
+getPalette = colorRampPalette(brewer.pal(9, "Set1")) # Set1 has at most 9 colors; interpolate beyond that
+celltype_colors = getPalette(length(unique(meta_data[,1])))
+celltype_colors <- setNames(celltype_colors, unique(meta_data[,1]))
+treatment_colors <- c('Young'='grey66', 'Old'='grey33')
+
+meta_data[,2] <- factor(meta_data[,2], levels=order_age)
+#Create annotation column
+colAnn_top <- HeatmapAnnotation(
+ `Cell Type` = meta_data[,1],
+ `Age` = meta_data[,2],
+ col = list('Cell Type' = celltype_colors,
+ 'Age' = treatment_colors),
+ show_annotation_name = FALSE
+ )
+
+across_center <- 6
+# within_center <- 4
+gc()
+if (is.null(file_k_clust_across)){
+ k_clust_across <- kmeans(tfa_z_across, centers=across_center,nstart=20, iter.max = 50)
+}else{
+ k_clust_across <- readRDS(file_k_clust_across)
+}
+
+# gc()
+# if (is.null(file_k_clust_within)){
+# k_clust_within <- kmeans(tfa_z_within, centers=within_center, nstart=20, iter.max = 50)
+# }else{
+# k_clust_within <- readRDS(file_k_clust_within)
+# }
+
+# Cluster rows by the k-means fit above (avoid re-running kmeans here,
+# which would discard the conditional load from file_k_clust_across)
+split_by <- factor(k_clust_across$cluster, levels = 1:across_center)
+
+heat_col <- colorRamp2(c(-2, 0, 2), c('dodgerblue3','white','red')) ## Colors for heatmap
+pdf(file.path(dirOut, "htmap_zscore_across.pdf"), width = 4, height = 6, compress = T)
+ht1 <- Heatmap(tfa_z_across,
+ name='Z-score',
+ show_row_names = F,
+ row_split=split_by,
+ row_gap = unit(0, "mm"),
+ column_gap = unit(0, "mm"),
+ border = TRUE,
+ row_title = NULL,
+ row_dend_reorder = F,
+ use_raster = F,
+ col = heat_col,
+ # clustering settings
+ cluster_rows = TRUE, # allow hierarchical clustering
+ cluster_row_slices = TRUE, # reorder within each k-means cluster
+ cluster_columns = FALSE,
+ cluster_column_slices = F,
+ # column_order = colnames(tfa_z_across),
+
+ # row/column dendrograms
+ show_row_dend = FALSE, # show dendrogram within slices
+ show_column_dend = FALSE,
+
+ show_column_names = FALSE,
+ column_names_side = 'top',
+ # column_names_rot = 45,
+ column_title = NULL,
+ column_split = meta_data[,1],
+ top_annotation = colAnn_top
+ #bottom_annotation = colAnn_bottom
+ )
+draw(ht1)
+dev.off()
+
+
+# cluster_annot <- data.frame(across=k_clust_across$cluster,within=k_clust_within$cluster)
+# cluster_annot$cluster <- paste0(cluster_annot$across,cluster_annot$within)
+# cluster_annot <- cluster_annot[order(cluster_annot$cluster), ]
+# within_clust_order <- c(1,2,3,4)
+# across_clust_order <- c(2,3,5,4,6,1)
+# cluster_annot$within <- factor(cluster_annot$within, levels=within_clust_order, ordered=T)
+# cluster_annot$across <- factor(cluster_annot$across, levels=across_clust_order, ordered=T)
+
+
+# annot_cols_across <- c('1'='#cc0001','2'='#fb940b','3'='#ffff01','4'='#01cc00','5'='#2085ec','6'='#fe98bf','7'='#762ca7','8'='#ad7a5b',
+# '9'='grey50', '10'='turquoise4')
+
+# annot_cols_within <- c('1'='#badf55','2'='#35b1c9','3'='#b06dad','4'="#14A76C", '5' = 'grey90', '6' = '#e96060')
+
+# #reorder clusters
+# split <- factor(cluster_annot$cluster, levels=c(
+# 21,22,23,24,
+# 31,32,33,34,
+# 51,52,53,54,
+# 41,42,43,44,
+# 61,62,63,64,
+# 11,12,13,14
+# ))
+
+# cluster_annot$cluster <- factor(cluster_annot$cluster, levels=levels(split))
+# cluster_annot <- cluster_annot[order(cluster_annot$cluster),]
+# names(split) <- rownames(cluster_annot)
+
+# rowAnn_across_left1 <- HeatmapAnnotation(`Across` = cluster_annot[,'across'], #df = cluster_annot[,'across'],
+# col = list('Across'= annot_cols_across),
+# which = 'row',
+# simple_anno_size = unit(3, 'mm'),
+# #annotation_width = unit(c(1, 4), 'cm'),
+# #gap = unit(0, 'mm'),
+# show_annotation_name = F)
+
+
+# pdf(file.path(dirOut, "htmap_zscore_across1.pdf"), width = 4, height = 6, compress = T)
+# ht1 <- Heatmap(tfa_z_across[rownames(cluster_annot),],
+# name='Z-score',
+# show_row_names = F,
+# row_split=cluster_annot$across,
+# show_column_dend =F,
+# row_gap = unit(0, "mm"),
+# column_gap = unit(0, "mm"),
+# border = TRUE,
+# row_title = NULL,
+# row_dend_reorder = F,
+# use_raster = F,
+# col = heat_col,
+# column_names_gp = gpar(fontsize = 8, fontface = 2),
+# show_row_dend = FALSE,
+# cluster_rows = TRUE,
+# cluster_columns = F,
+# cluster_row_slices = F,
+# cluster_column_slices = F,
+# column_order = colnames(tfa_z_across),
+# show_column_names = FALSE,
+# column_names_side = 'top',
+# # column_labels = gt_render(column_label),
+# column_names_rot = 45,
+# column_title = NULL,
+# column_split = meta_data[,1],
+# left_annotation = rowAnn_across_left1,
+# # right_annotation = rowAnn_across_right1,
+# row_order = rownames(cluster_annot),
+# top_annotation = colAnn_top
+# #bottom_annotation = colAnn_bottom
+# )
+
+
+
+# PART C: TFA Visualization between data representations
+tfaFiles <- list(
+ "PB" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/Combined/TFA.txt",
+ "SC" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/SC/ATAC/geneLambda0p5_220totSS_20tfsPerGene_subsamplePCT5/Combined/TFA.txt",
+ "MC2" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/metaCells/MC2/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63_logNorm/Combined/TFA.txt",
+ "SEA" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/metaCells/SEACells/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/Combined/TFA.txt"
+)
+
+tfaMatList <- list()
+for (typeName in names(tfaFiles)) {
+ filePath <- tfaFiles[[typeName]]
+ tfaMatList[[typeName]] <- read.table(filePath, header = TRUE, sep = "\t", stringsAsFactors = FALSE)
+}
+
+numTFA <- length(tfaMatList)
+tfaNames <- names(tfaMatList)
+
+
+
+# -----------------------------
+# Feature Plot
+# -----------------------------
+objPath <- "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/annotation_scrna_final/obj_Tfh10_RNA_annotated.rds"
+tfa_file1 <- "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/SC/ATAC/geneLambda0p5_220totSS_20tfsPerGene_subsamplePCT5/Combined/TFA.txt"
+obj <- readRDS(objPath)
+tfa_mat1 <- read.table(tfa_file1, header = T)
+lineage_TFs <- c("Bcl6","Maf","Batf","Tox","Ascl2",
+                 "Foxp3","Ikzf2",
+                 "Rorc","Stat3",
+                 "Tbx21","Stat4","Eomes","Prdm1",
+                 "Tcf7","Klf2","Lef1")
+
+# Filter lineage TFs present in your TFA matrix
+lineage_TFs_present <- intersect(lineage_TFs, rownames(tfa_mat1))
+
+# Add each TFA as metadata (coerce the data.frame row to a numeric vector)
+for(tf in lineage_TFs_present){
+  obj[[tf]] <- as.numeric(tfa_mat1[tf, colnames(obj)])
+}
+
+# Example: FeaturePlot for all lineage TFs
+FeaturePlot(obj,
+            features = lineage_TFs_present,
+            cols = c("lightgrey","red"),
+            reduction = "umap") # or "tsne" if you use tSNE
+
diff --git a/evaluation/R/evaluateNetUtils.R b/evaluation/R/evaluateNetUtils.R
new file mode 100755
index 0000000..bea8551
--- /dev/null
+++ b/evaluation/R/evaluateNetUtils.R
@@ -0,0 +1,644 @@
+suppressPackageStartupMessages({
+  library(dplyr)
+  library(tidyr)
+  library(purrr)   # map_dfr
+  library(tibble)  # tibble()
+  library(ggplot2)
+  library(ComplexHeatmap)
+  library(pheatmap)
+  # library(VennDiagram)
+  # library(ggVennDiagram)
+  library(grid)
+  library(reshape2)
+  library(RColorBrewer)
+})
+
+# ==========================
+# PART A: EDGE SET CONSTRUCTION
+# ==========================
+
+# -------- 1. Build Edge Sets -----------
+getEdgeSets <- function(netDataList, tfCol="TF", targetCol="Gene", rankCol="signedQuantile",
+ mode = c("topNpercent", "topN"), N = NULL) {
+ #' getEdgeSets
+ #'
+ #' Extract edge sets from a list of networks, optionally selecting top-ranked edges.
+ #'
+ #' @param netDataList A named list of data.frames. Each data.frame represents a network with at least
+ #' the columns for TF, target gene, and ranking.
+ #' @param tfCol Character. Name of the column containing transcription factors (default "TF").
+ #' @param targetCol Character. Name of the column containing target genes (default "Gene").
+ #' @param rankCol Character. Name of the column containing ranking scores (default "signedQuantile").
+ #' @param mode Character. Method to select edges when N is specified. Choices are:
+ #' "topNpercent" (default) - select top N percent of edges based on rankCol,
+ #' "topN" - select top N edges.
+ #' Ignored if N is NULL.
+ #' @param N Numeric. Number of edges (for "topN") or percent (for "topNpercent") to select.
+ #' If NULL (default), all edges are returned and mode is ignored.
+ #'
+ #' @return A named list of character vectors. Each element contains edges in the form "TF~Gene".
+ #'
+ mode <- match.arg(mode)
+
+  edgeSets <- lapply(seq_along(netDataList), function(i) {
+    df <- netDataList[[i]]
+    # Validate the ranking column before using it
+    if(!rankCol %in% colnames(df)) stop(paste("Column", rankCol, "not found in network", names(netDataList)[i]))
+
+    df <- df %>%
+      filter(.data[[rankCol]] != 0) %>%
+      arrange(desc(.data[[rankCol]]))
+
+ df <- df %>%
+ mutate(edge = paste(.data[[tfCol]], .data[[targetCol]], sep="~"))
+ # If N is NULL, return all edges
+ if (is.null(N)) {
+ return(df$edge)
+ }
+
+ if(mode == "topN") {
+ df <- df %>% slice_head(n = N)
+    } else { # topNpercent: threshold on |score| so the quantile matches the filter below
+      cutoff <- quantile(abs(df[[rankCol]]), 1 - N/100, na.rm = TRUE)
+      df <- df %>% filter(abs(.data[[rankCol]]) >= cutoff)
+ }
+
+ df$edge
+ })
+
+ names(edgeSets) <- names(netDataList)
+ return(edgeSets)
+}
+
+# ======================================================
+# PART B: Pairwise Overlap / Set Metrics & Correlation
+# ======================================================
+
+# - 1. Global Network Jaccard -----------
+computeJaccardMatrix <- function(edgeSets) {
+ networkNames <- names(edgeSets)
+ numNetworks <- length(edgeSets)
+
+ pairInt <- matrix(NA, nrow=numNetworks, ncol=numNetworks, dimnames=list(networkNames, networkNames))
+ pairJac <- matrix(NA, nrow=numNetworks, ncol=numNetworks, dimnames=list(networkNames, networkNames))
+
+ if (numNetworks > 1) {
+ for (i in 1:(numNetworks-1)) {
+ for (j in (i+1):numNetworks) {
+ shared <- intersect(edgeSets[[i]], edgeSets[[j]])
+ un <- union(edgeSets[[i]], edgeSets[[j]])
+        pairInt[i,j] <- pairInt[j,i] <- length(shared)
+        pairJac[i,j] <- pairJac[j,i] <- if (length(un)) length(shared)/length(un) else NA
+ }
+ }
+ }
+
+ diag(pairInt) <- vapply(edgeSets, length, integer(1))
+ diag(pairJac) <- 1
+
+ return(list(pairJac = pairJac, pairInt = pairInt))
+}
+
+
+# ==============================================================================
+# - 2. Compute Spearman Correlation between pairs of ranks or weights
+# ==============================================================================
+#'
+#' @param netDataList Named list of data.frames. Each data.frame must contain TF, target gene, and score columns.
+#' @param edgeSets Optional named list of character vectors (from getEdgeSets). If provided, only these edges are used.
+#' @param tfCol Character. Name of TF column (default "TF").
+#' @param targetCol Character. Name of target gene column (default "Gene").
+#' @param rankCol Character. Name of score column (default "signedQuantile").
+#' @param ref Optional character. Name of one network to use as a fixed reference; if supplied, only reference-vs-other pairs are computed.
+#'
+#' @return A tibble with columns: Net1, Net2, Spearman correlation.
+#'
+#' @examples
+#' # Spearman on all edges
+#' computeSpearman(netDataList)
+#'
+#' # Spearman on top edges only
+#' edges_top10pct <- getEdgeSets(netDataList, N=10)
+#' computeSpearman(netDataList, edgeSets = edges_top10pct)
+computeSpearman <- function(netDataList,
+ edgeSets = NULL,
+ ref = NULL,
+ tfCol="TF",
+ targetCol="Gene",
+ rankCol="signedQuantile") {
+
+ prepForCor <- function(df, edges = NULL) {
+ df <- df %>%
+ filter(.data[[rankCol]] != 0) %>%
+ mutate(edge = paste(.data[[tfCol]], .data[[targetCol]], sep="~")) %>%
+ dplyr::select(edge, score = .data[[rankCol]])
+
+ if(!is.null(edges)) {
+ df <- df %>% filter(edge %in% edges)
+ }
+
+ return(df)
+ }
+
+ # Prepare each network
+ nets <- lapply(seq_along(netDataList), function(i) {
+ edges_to_use <- if(!is.null(edgeSets)) edgeSets[[names(netDataList)[i]]] else NULL
+ prepForCor(netDataList[[i]], edges = edges_to_use)
+ })
+
+ netNames <- names(netDataList)
+ names(nets) <- netNames
+
+ # --- Compute pairwise Spearman correlations
+ # Build comparison pairs
+ if(!is.null(ref)){
+ if(!ref %in% netNames){
+ stop("Reference not found in edgeSets")
+ }
+ combs <- lapply(setdiff(netNames, ref), function(x)
+ c(ref, x))
+ } else {
+ combs <- combn(netNames, 2, simplify = FALSE)
+ }
+
+ map_dfr(combs, function(p) {
+ a <- nets[[p[1]]]
+ b <- nets[[p[2]]]
+
+ joined <- inner_join(a, b, by = "edge", suffix = c(".1", ".2"))
+ rho <- if(nrow(joined) > 1) {
+ cor(joined$score.1, joined$score.2, method = "spearman", use = "complete.obs")
+ } else {NA}
+
+ tibble(
+ Net1 = p[1],
+ Net2 = p[2],
+ Spearman = rho
+ )
+ })
+}
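`computeSpearman` ultimately calls `cor(..., method = "spearman")` on the joined score columns; Spearman depends only on ranks, not raw magnitudes. A toy pair of score vectors (values hypothetical):

```r
# Spearman rank correlation of two score vectors over the same four edges.
s1 <- c(0.9, 0.5, 0.1, 0.7)   # ranks: 4 2 1 3
s2 <- c(0.8, 0.6, 0.2, 0.3)   # ranks: 4 3 1 2
rho <- cor(s1, s2, method = "spearman")
rho  # 1 - 6*sum(d^2)/(n*(n^2-1)) = 1 - 12/60 = 0.8
```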
+
+
+# ==== 3. Compute edge overlaps ====
+computeOverlapStats <- function(edgeSets, ref = NULL){
+ nets <- names(edgeSets)
+ # Build comparison pairs
+ if(!is.null(ref)){
+ if(!ref %in% nets){
+ stop("Reference not found in edgeSets")
+ }
+ combs <- lapply(setdiff(nets, ref), function(x)
+ c(ref, x))
+ } else {
+ combs <- combn(nets, 2, simplify = FALSE)
+ }
+ # combs <- combn(names(edgeSets), 2, simplify=FALSE)
+
+ map_dfr(combs, function(p){
+ A <- edgeSets[[p[1]]]
+ B <- edgeSets[[p[2]]]
+
+ nA <- length(A)
+ nB <- length(B)
+
+ shared <- length(intersect(A,B))
+ un <- length(union(A,B))
+
+ # Dice Coefficient (or Sorensen-Dice Index)
+ dice <- ifelse(
+ nA + nB > 0,
+ 2 * shared / (nA + nB),
+ NA
+ )
+
+ # What fraction of the smaller set is shared with the larger set?
+ # Szymkiewicz–Simpson coefficient
+ fracOverlap =
+ ifelse(min(nA,nB)>0,
+ shared / min(nA,nB),
+ NA)
+ tibble(
+ Net1 = p[1],
+ Net2 = p[2],
+ pairID = paste0(p[1], "~", p[2]),
+ size1 = nA,
+ size2 = nB,
+ Intersection = shared,
+ frac1in2 = ifelse(nA > 0, shared / nA, NA),
+ frac2in1 = ifelse(nB > 0, shared / nB, NA),
+ Jaccard = ifelse(un> 0, shared/un, NA),
+ Dice = dice,
+ FractionOverlap = fracOverlap
+
+ )
+ })
+}
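The Dice and Szymkiewicz-Simpson (fraction-overlap) coefficients computed above differ only in their denominator: Dice averages the two set sizes, while Szymkiewicz-Simpson normalizes by the smaller set. A quick numeric contrast on toy sets:

```r
# Dice vs overlap (Szymkiewicz-Simpson) coefficient on the same two sets.
A <- c("e1", "e2", "e3"); B <- c("e1", "e4")
shared  <- length(intersect(A, B))
dice    <- 2 * shared / (length(A) + length(B))  # 2*1/5 = 0.4
overlap <- shared / min(length(A), length(B))    # 1/2  = 0.5
c(Dice = dice, Overlap = overlap)
```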
+
+# ==== 4. Compute TF overlaps ====
+computeTFoverlap <- function(netDataList, ref = NULL, tfCol = "TF") {
+ # Build TF sets
+ tfSets <- lapply(netDataList, function(df) unique(df[[tfCol]]))
+
+ nets <- names(tfSets)
+
+ # Build comparison pairs
+ if(!is.null(ref)){
+ if(!ref %in% nets){
+ stop("Reference not found in netDataList")
+ }
+ combs <- lapply(setdiff(nets, ref), function(x) c(ref, x))
+ } else {
+ combs <- combn(nets, 2, simplify = FALSE)
+ }
+
+ # Compute pairwise overlaps
+ map_dfr(combs, function(p){
+ A <- tfSets[[p[1]]]
+ B <- tfSets[[p[2]]]
+ nA <- length(A)
+ nB <- length(B)
+ ov <- intersect(A,B)
+ shared <- length(ov)
+ un <- length(union(A,B))
+
+ # dice <- ifelse(nA + nB > 0, 2 * shared / (nA + nB), NA)
+ # fracOverlap <- ifelse(min(nA,nB) > 0, shared / min(nA,nB), NA)
+
+ tibble(
+ Net1 = p[1],
+ Net2 = p[2],
+ pairID = paste0(p[1], "~", p[2]),
+ size1 = nA,
+ size2 = nB,
+ Overlaps = paste(ov, collapse = ","),
+ nIntersection = shared,
+ frac1in2 = ifelse(nA > 0, shared / nA, NA),
+ frac2in1 = ifelse(nB > 0, shared / nB, NA)
+ # Jaccard = ifelse(un > 0, shared/un, NA),
+ # Dice = dice,
+ # FractionOverlap = fracOverlap
+ )
+ })
+}
+
+# ====================================================
+# PART C: TF-Level / Aggregated Metrics
+# ====================================================
+
+# ===== 1. TF-centric Jaccard ====
+#' For each TF, how consistent are its predicted targets across networks?
+#' Biological networks are TF-driven, so checking target consistency per TF
+#' Jaccard penalizes sets with different sizes. It measures proportion of shared elements out of all unique elements
+#'
+computeTFJaccard <- function(netDataList, tfList=NULL){
+ networkNames <- names(netDataList)
+ numNetworks <- length(netDataList)
+
+ allTFs <- unique(unlist(lapply(netDataList, function(df) df$TF)))
+ tfs <- if(!is.null(tfList)) intersect(allTFs, tfList) else allTFs
+
+  pairIdx <- combn(seq_len(numNetworks), 2)  # precompute once instead of per TF per pair
+  pairNames <- combn(networkNames, 2, function(x) paste(x, collapse="~"))
+  tfMatrix <- matrix(NA, nrow=length(tfs), ncol=length(pairNames), dimnames=list(tfs, pairNames))
+
+  for(tf in tfs){
+    tfTargets <- lapply(netDataList, function(df) df$Gene[df$TF == tf])
+    for(k in seq_along(pairNames)){
+      pair <- pairIdx[,k]
+      shared <- intersect(tfTargets[[pair[1]]], tfTargets[[pair[2]]])
+      un <- union(tfTargets[[pair[1]]], tfTargets[[pair[2]]])
+      tfMatrix[tf,k] <- if(length(un)) length(shared)/length(un) else NA
+    }
+  }
+ return(tfMatrix)
+}
+
+# plotTFHeatmap <- function(tfMatrix, fileOut, fontsize_number=7, fontsize=9, fig_width = 6, fig_height = 6){
+# heatCols <- colorRampPalette(RColorBrewer::brewer.pal(9,"YlOrRd"))(100)
+# pdf(fileOut, width = fig_width, height = fig_height)
+# pheatmap(tfMatrix,
+# display_numbers=TRUE,
+# number_color="black",
+# number_format="%.2f",
+# fontsize_number=fontsize_number,
+# fontsize=fontsize,
+# cluster_rows=FALSE,
+# cluster_cols=FALSE,
+# show_rownames = TRUE,
+# show_colnames = TRUE,
+# legend = FALSE , # removes grid lines
+# color=heatCols,
+# main=""
+# )
+# dev.off()
+# }
+
+# ==== 2. Aggregated TF Jaccard per Network Pair =====
+computeAggregatedPairJaccard <- function(tfMatrix, fun=median){
+ # tfMatrix: rows = TFs, cols = network pairs
+ aggVec <- apply(tfMatrix, 2, fun, na.rm=TRUE)
+
+ # Convert to symmetric matrix
+ pairNames <- colnames(tfMatrix)
+ nets <- unique(unlist(strsplit(pairNames, "~")))
+ numNetworks <- length(nets)
+ aggMat <- matrix(NA, nrow=numNetworks, ncol=numNetworks, dimnames=list(nets,nets))
+
+ for(k in seq_along(pairNames)){
+ pair <- strsplit(pairNames[k], "~")[[1]]
+ aggMat[pair[1], pair[2]] <- aggVec[k]
+ aggMat[pair[2], pair[1]] <- aggVec[k]
+ }
+ diag(aggMat) <- 1
+ return(aggMat)
+}
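The loop in `computeAggregatedPairJaccard` rebuilds a symmetric network-by-network matrix from a vector keyed by "Net1~Net2" pair names. A self-contained sketch of that conversion (network names and values hypothetical):

```r
# Rebuild a symmetric matrix from pair-named values ("N1~N2"-style keys).
aggVec <- c("N1~N2" = 0.3, "N1~N3" = 0.5, "N2~N3" = 0.2)
nets <- unique(unlist(strsplit(names(aggVec), "~")))
m <- matrix(NA, length(nets), length(nets), dimnames = list(nets, nets))
for (k in names(aggVec)) {
  p <- strsplit(k, "~")[[1]]
  m[p[1], p[2]] <- m[p[2], p[1]] <- aggVec[[k]]  # mirror into both triangles
}
diag(m) <- 1
m["N3", "N1"]  # 0.5 (symmetric)
```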
+
+
+#' ==== 5. Compute and plot hub TF set overlap across networks ====
+#'
+#' This function identifies the top N hub transcription factors (TFs) in each
+#' network, computes pairwise similarity between hub sets across networks using
+#' either the Jaccard index or the overlap coefficient, and plots a heatmap of
+#' the similarity matrix.
+#'
+#' @param netDataList List of data frames, one per network, each containing edges.
+#' @param tfCol Column name (string) indicating which column contains TF identifiers.
+#' @param networkNames Character vector of network names, matching the order of netDataList.
+#' @param dirOut Output directory for saving the heatmap PDF.
+#' @param topN Integer. Number of top hubs (by edge count) to include per network. Default = 50.
+#' @param metric Similarity metric. One of "jaccard" or "overlap". Default = "jaccard".
+#' @param heatCols Color palette for the heatmap. Default = 100-color YlOrRd.
+#' @param fontsize Base font size for heatmap text. Default = 9.
+#' @param fontsize_number Font size for numbers displayed in cells. Default = 7.
+#' @param fig_width Width of the PDF figure in inches. Default = 6.
+#' @param fig_height Height of the PDF figure in inches. Default = 6.
+
+#'
+#' @return Invisibly returns the similarity matrix used to generate the heatmap.
+#'
+#' @examples
+#' hubSim <- computeHubOverlapHeatmap(
+#' netDataList = netDataList,
+#' tfCol = "TF",
+#' networkNames = c("Net1", "Net2", "Net3"),
+#' dirOut = "results/",
+#' topN = 50,
+#' metric = "jaccard",
+#' fig_width = 8,
+#' fig_height = 8
+#' )
+#'
+#' @export
+computeHubOverlapHeatmap <- function(netDataList, tfCol, networkNames,
+ dirOut, topN = 50,
+ metric = c("jaccard", "overlap"),
+ heatCols = colorRampPalette(RColorBrewer::brewer.pal(9, "YlOrRd"))(100),
+ fontsize = 9, fontsize_number = 7, fig_width = 6, fig_height = 6) {
+ # --- Argument check
+ metric <- match.arg(metric)
+ numNetworks <- length(netDataList)
+
+ # --- 1. Count edges per TF and rank hubs
+ hubRankList <- lapply(netDataList, function(df) {
+ tfCounts <- table(df[[tfCol]])
+ sort(tfCounts, decreasing = TRUE)
+ })
+
+ # --- 2. Take top N hubs
+ topHubs <- lapply(hubRankList, function(x) head(names(x), topN))
+
+ # --- 3. Initialize similarity matrix
+ hubSim <- matrix(NA, numNetworks, numNetworks,
+ dimnames = list(networkNames, networkNames))
+
+ # --- 4. Compute similarity
+ for (i in 1:(numNetworks-1)) {
+ for (j in (i+1):numNetworks) {
+ inter <- intersect(topHubs[[i]], topHubs[[j]])
+
+      if(metric == "jaccard") {
+        un <- union(topHubs[[i]], topHubs[[j]])  # avoid shadowing base::union
+        hubSim[i,j] <- hubSim[j,i] <- length(inter) / length(un)
+ } else if(metric == "overlap") {
+ minSize <- min(length(topHubs[[i]]), length(topHubs[[j]]))
+ hubSim[i,j] <- hubSim[j,i] <- length(inter) / minSize
+ }
+ }
+ }
+ diag(hubSim) <- 1
+
+ # --- 5. Plot heatmap
+ pdf(file.path(dirOut, paste0("HubTF_Overlap_", metric, "_top", topN, ".pdf")), width = fig_width, height = fig_height)
+ pheatmap(hubSim,
+ display_numbers = TRUE,
+ number_color = "black",
+ number_format = "%.2f",
+ fontsize_number = fontsize_number,
+ fontsize = fontsize,
+ cluster_rows = FALSE,
+ cluster_cols = FALSE,
+ show_rownames = TRUE,
+ show_colnames = TRUE,
+ legend = FALSE,
+ color = heatCols,
+ main = "")
+ dev.off()
+
+ invisible(hubSim)
+}
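As a sanity check on the two similarity metrics above, here is a standalone toy example (the hub names are invented, not from any real network) that computes the Jaccard index and the overlap coefficient the same way the pairwise loop does:

```r
# Toy hub sets (hypothetical TF names, for illustration only)
hubsA <- c("Tbx21", "Stat4", "Foxp3", "Batf")
hubsB <- c("Foxp3", "Batf", "Rorc")

inter   <- intersect(hubsA, hubsB)                             # shared hubs
jaccard <- length(inter) / length(union(hubsA, hubsB))         # |A∩B| / |A∪B|
overlap <- length(inter) / min(length(hubsA), length(hubsB))   # |A∩B| / min(|A|,|B|)

jaccard  # 0.4
overlap  # 0.6666667
```

Note that the overlap coefficient is always ≥ the Jaccard index, which is why the two heatmaps can look quite different for sets of unequal size.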
+
+# ==================================================
+# - Plotting Utils
+# ==================================================
+
+
+# ==== Edge Sharing ====
+plotEdgeSharing <- function(edgeSets, fileSuffix, outDir=".") {
+ library(ggVennDiagram)
+ library(ComplexUpset)
+ library(ComplexHeatmap)
+
+ numNetworks <- length(edgeSets)
+ dir.create(outDir, showWarnings = FALSE, recursive = TRUE)
+
+  if(numNetworks <= 3){
+    # --- Venn diagram (ggplot objects must be print()ed inside a pdf() device)
+    pdf(file.path(outDir, paste0("Venn2_", fileSuffix, ".pdf")), width=6, height=6)
+    print(ggVennDiagram(edgeSets, label_alpha = 0.6) +
+      scale_fill_gradient(low="grey90", high="red") +
+      theme(legend.position="none"))
+    dev.off()
+
+ } else {
+    # --- UpSet plots for >3 networks
+ # Convert edgeSets to a presence/absence matrix
+ allEdges <- unique(unlist(lapply(edgeSets, function(df) paste(df$TF, df$Gene, sep="~"))))
+ incMat <- sapply(edgeSets, function(df) as.integer(allEdges %in% paste(df$TF, df$Gene, sep="~")))
+ colnames(incMat) <- names(edgeSets)
+ rownames(incMat) <- allEdges
+
+ # Intersect mode
+ mObjIntersect <- make_comb_mat(incMat, mode="intersect")
+    pdf(file.path(outDir, paste0("UpSet_Intersect_", fileSuffix, ".pdf")), width=8.5, height=4)
+ draw(UpSet(mObjIntersect,
+ set_order = names(edgeSets),
+ top_annotation = upset_top_annotation(mObjIntersect, annotation_name_rot=90, axis=FALSE,
+ add_numbers=TRUE, numbers_rot=90,
+ gp=gpar(col=comb_degree(mObjIntersect), fontsize=6), height=unit(4,"cm")),
+ right_annotation = upset_right_annotation(mObjIntersect, gp=gpar(fill="black", fontsize=6),
+ width=unit(4,"cm"), show_annotation_name=FALSE, add_numbers=TRUE)))
+ dev.off()
+
+ # Distinct mode
+ mObjDistinct <- make_comb_mat(incMat, mode="distinct")
+ pdf(file.path(outDir, paste0("UpSet_distinct_", fileSuffix, ".pdf")), width=8.5, height=4)
+ draw(UpSet(mObjDistinct,
+ set_order = names(edgeSets),
+ top_annotation = upset_top_annotation(mObjDistinct, annotation_name_rot=90, axis=FALSE,
+ add_numbers=TRUE, numbers_rot=90,
+ gp=gpar(col=comb_degree(mObjDistinct), fontsize=6), height=unit(4,"cm")),
+ right_annotation = upset_right_annotation(mObjDistinct, gp=gpar(fill="black", fontsize=6),
+ width=unit(4,"cm"), show_annotation_name=FALSE, add_numbers=TRUE)))
+ dev.off()
+ }
+
+ message("Edge sharing plots generated.")
+}
+
+
+plotSimilarityHeatmap <- function(simMat, fileOut="sim_heatmap.pdf",
+ fontsize_number=7, fontsize=9, fig_width=6, fig_height=6) {
+ simMat[lower.tri(simMat)] <- t(simMat)[lower.tri(simMat)]
+ heatCols <- colorRampPalette(RColorBrewer::brewer.pal(9,"YlOrRd"))(100)
+ pdf(fileOut, height = fig_height, width = fig_width)
+ pheatmap::pheatmap(simMat,
+ display_numbers=TRUE,
+ number_color="black",
+ number_format="%.2f",
+ fontsize_number=fontsize_number,
+ fontsize=fontsize,
+ cluster_rows=FALSE,
+ cluster_cols=FALSE,
+ show_rownames = TRUE,
+ show_colnames = TRUE,
+ legend = FALSE ,
+ color=heatCols,
+ main=""
+ )
+ dev.off()
+}
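The first line of `plotSimilarityHeatmap()` mirrors the upper triangle into the lower one so `pheatmap` receives a complete matrix. A minimal check of that idiom on a small matrix:

```r
# Build a 3x3 matrix with values only in the upper triangle
m <- matrix(NA_real_, 3, 3, dimnames = list(letters[1:3], letters[1:3]))
m[upper.tri(m)] <- c(0.2, 0.5, 0.8)
diag(m) <- 1

# Copy the upper triangle into the lower (same line as in the function)
m[lower.tri(m)] <- t(m)[lower.tri(m)]
isSymmetric(m)  # TRUE
```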
+
+plotAggregatedJaccard <- function(aggMat, fileOut, fontsize_number=7, fontsize=9, fig_width = 6, fig_height = 6){
+ heatCols <- colorRampPalette(RColorBrewer::brewer.pal(9,"YlOrRd"))(100)
+ pdf(fileOut, width = fig_width, height = fig_height)
+ pheatmap(aggMat,
+ display_numbers=TRUE,
+ number_color="black",
+ number_format="%.2f",
+ fontsize_number=fontsize_number,
+ fontsize=fontsize,
+ cluster_rows=FALSE,
+ cluster_cols=FALSE,
+ show_rownames = TRUE,
+ show_colnames = TRUE,
+                     legend = FALSE, # suppress the color legend
+ color=heatCols,
+ main="Aggregated TF Jaccard per Network Pair")
+ dev.off()
+}
+
+plotMetricTrend <- function(df, xCol, yCol, groupCol = "pairID", colorMap = NULL,
+ xLabel = NULL, yLabel = NULL, yLimits = NULL, xBreaks = NULL,
+ outFile = NULL, width = 6, height = 3, dpi = 600){
+  # .data pronoun replaces the deprecated aes_string()
+  p <- ggplot(df,
+              aes(x = .data[[xCol]], y = .data[[yCol]],
+                  color = .data[[groupCol]], group = .data[[groupCol]])) +
+ geom_line(linewidth = 0.5) +
+ geom_point(size = 1) +
+ theme_bw(base_family = "Helvetica") +
+ theme(
+ axis.text.x = element_text(size = 7, color = "black"),
+ axis.text.y = element_text(size = 7, color = "black"),
+ axis.title = element_text(size = 9, color = "black"),
+ legend.title = element_blank(),
+ legend.text = element_text(size = 9, color = "black"),
+ legend.position = "right",
+ legend.box.just = "left",
+ panel.grid.major = element_line(color = "grey80", linewidth = 0.3),
+ panel.grid.minor = element_line(color = "grey90", linewidth = 0.1)
+ )
+
+ # Optional scales
+ if(!is.null(colorMap)){
+ p <- p + scale_color_manual(values = colorMap)
+ }
+
+ if(!is.null(xBreaks)){
+ p <- p + scale_x_continuous(breaks = xBreaks)
+ }
+
+ if(!is.null(yLimits)){
+ p <- p + scale_y_continuous(limits = yLimits)
+ } else{
+ dataMin <- min(df[[yCol]], na.rm = TRUE)
+ dataMax <- max(df[[yCol]], na.rm = TRUE)
+ yMin <- if (dataMin >= 0) 0 else dataMin
+ yMax <- max(dataMax, 1)
+ yLimits <- c(yMin, yMax)
+ p <- p + scale_y_continuous(limits = yLimits)
+ }
+ # Labels
+ p <- p + labs( x = xLabel, y = yLabel )
+ # Save if requested
+ if(!is.null(outFile)) ggsave(outFile, plot = p, width = width, height = height, dpi = dpi)
+
+ return(p)
+}
+
+# ==========================================
+# High-Level Pipeline / API (Top Level)
+# ==========================================
+computeTopNMetrics <- function(netDataList,
+ ref = NULL,
+ topNs = NULL,
+ tfCol="TF",
+ targetCol="Target",
+ rankCol="Weight",
+ mode = c("topNpercent", "topN")) {
+  #' Compute Top-N overlap and correlation metrics for multiple networks
+  #'
+  #' @param netDataList Named list of network data.frames
+  #' @param ref Optional reference network passed through to computeOverlapStats/computeSpearman
+  #' @param topNs Numeric vector of percentages (Top N%) or edge counts to evaluate; NULL means all edges
+  #' @param tfCol Character, column name for transcription factor
+  #' @param targetCol Character, column name for target gene
+  #' @param rankCol Character, column name for ranking score
+  #' @param mode Either "topNpercent" or "topN", controlling how topNs is interpreted
+  #'
+  #' @return data.frame with columns: pairID, Jaccard, FractionOverlap, Spearman, Npct
+  mode <- match.arg(mode)
+ if (is.null(topNs)) topNs <- "all"
+ topNResults <- lapply(topNs, function(k) {
+ cat("Top-N:", k, "\n")
+
+ if (k == "all"){
+ edges <- getEdgeSets(netDataList, tfCol = tfCol,
+ targetCol = targetCol, rankCol = rankCol)
+ } else{
+ # Extract Top-N% edges
+ edges <- getEdgeSets(netDataList, tfCol = tfCol,
+ targetCol = targetCol, rankCol = rankCol,
+ mode = mode, N = k)
+ }
+ # Overlap statistics
+ stats <- computeOverlapStats(edges, ref)
+ stats$Npct <- if(k == "all") NA else k
+
+ # Spearman correlation
+ corRes <- computeSpearman(netDataList, edgeSets = edges, ref = ref, tfCol = tfCol, targetCol = targetCol, rankCol = rankCol)
+ # Merge Spearman with stats
+ stats <- stats %>%
+ left_join(corRes %>%
+ mutate(pairID = paste(Net1, Net2, sep="~")) %>%
+ dplyr::select(pairID, Spearman),
+ by = "pairID")
+
+ stats
+ })
+
+ bind_rows(topNResults)
+}
+
diff --git a/evaluation/R/evaluateNetworks.R b/evaluation/R/evaluateNetworks.R
new file mode 100755
index 0000000..ee2671f
--- /dev/null
+++ b/evaluation/R/evaluateNetworks.R
@@ -0,0 +1,311 @@
+rm(list = ls())
+suppressPackageStartupMessages({
+ library(dplyr)
+ library(tidyr)
+ library(ggplot2)
+ library(ComplexHeatmap)
+ library(pheatmap)
+ # library(VennDiagram)
+ library(ggVennDiagram)
+ library(grid)
+ library(reshape2)
+})
+
+source("/data/miraldiNB/Michael/Scripts/VennDiagram.R")
+
+# dirOut <- "/data/miraldiNB/Michael/hCD4T_Katko/Inferelator/noMergedTF/Bulk/5kbTSS/newPipe/spikeIgnored/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63"
+dirOut <- "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/networkEval/"
+dir.create(dirOut, recursive = T, showWarnings = F)
+
+# Combined/combined_max.tsv
+# TFmRNA/edges_subset.txt
+# netFiles <- list(
+# "TFA" = "/data/miraldiNB/Michael/hCD4T_Katko/Inferelator/noMergedTF/SC/SCENIC/geneLambda0p5_220totSS_20tfsPerGene_subsamplePCT5_eRegulon/TFA/edges_subset.tsv",
+# "TFmRNA" = "/data/miraldiNB/Michael/hCD4T_Katko/Inferelator/noMergedTF/SC/SCENIC/geneLambda0p5_220totSS_20tfsPerGene_subsamplePCT5_eRegulon/TFmRNA/edges_subset.tsv",
+# "Combined" = "/data/miraldiNB/Michael/hCD4T_Katko/Inferelator/noMergedTF/SC/SCENIC/geneLambda0p5_220totSS_20tfsPerGene_subsamplePCT5_eRegulon/Combined/combined_max.tsv"
+# )
+
+netFiles <- list(
+ "TFA" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.tsv",
+ "TFmRNA" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/TFmRNA/edges_subset.tsv",
+ "Combined" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/Combined/combined_max.tsv"
+ )
+
+# priorFile = "/data/miraldiNB/Michael/hCD4T_Katko/dataBank/Priors/MotifScan5kbTSS_b_sp.tsv"
+priorFile = "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10_sp.tsv"
+# k <- read.table("/data/miraldiNB/Michael/hCD4T_Katko/dataBank/Priors/SCENICp/SCENICp_RE2Glinks_FIMO_5Kb_derived_b.tsv")
+# k$Target <- rownames(k)
+# b <- melt(k, id.vars = "Target", variable.names = "TF", value.name = "Weigths")
+# colnames(b)[2:3] <- c("TF", "Weights")
+# b <- b[, c("TF", "Target", "Weights")]
+# bb <- b[b$Weights != 0, ]
+# write.table(bb, "/data/miraldiNB/Michael/hCD4T_Katko/dataBank/Priors/SCENICp/SCENICp_RE2Glinks_FIMO_5Kb_derived_sp.tsv", row.names = F, sep ="\t", quote = F)
+
+tfCol <- "TF" # specify TF column name here
+targetCol <- "Gene" # specify Target column name here
+priorTfCol <- "TF" # TF column name in prior file
+priorTargetCol <- "Target" # Target column name in prior file
+priorWgtCol <-"Weight"
+nSelect <- 10
+stabCol <- "Stability"
+compareNets <- TRUE
+# ========================================================================
+# ------- Read Input files
+# ========================================================================
+
+# ----- Read Prior
+priorData <- read.table(priorFile, header = TRUE, sep = "\t", stringsAsFactors = FALSE)
+# priorData <- priorData[priorData[[priorWgtCol]] != 0, ]
+priorPairs <- unique(paste(priorData[[priorTfCol]], priorData[[priorTargetCol]], sep = "_"))
+
+# --- Preload all network data into a list ---
+netDataList <- list()
+for (typeName in names(netFiles)) {
+ filePath <- netFiles[[typeName]]
+ netDataList[[typeName]] <- read.table(filePath, header = TRUE, sep = "\t", stringsAsFactors = FALSE)
+}
+
+numNetworks <- length(netDataList)
+networkNames <- names(netDataList)
+# ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
+# ---- PART ONE: Make a table of number of network ppts - uniqueTFs, uniqueTargets, Total Interactions, and % supported by prior
+# ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
+
+# process each network
+summaryList <- list()
+for (typeName in networkNames) {
+ data <- netDataList[[typeName]]
+
+ uniqueTf <- length(unique(data[[tfCol]]))
+ uniqueTarget <- length(unique(data[[targetCol]]))
+
+ pairStrings <- unique(paste(data[[tfCol]], data[[targetCol]], sep = "_"))
+ totalInteractions <- length(pairStrings)
+
+ supportedCount <- length(intersect(pairStrings, priorPairs))
+ percentSupportedByPrior <- 100 * supportedCount / totalInteractions
+
+ summaryList[[typeName]] <- data.frame(type = typeName, uniqueTf = uniqueTf,
+ uniqueTarget = uniqueTarget, totalInteractions = totalInteractions,
+ percentSupportedByPrior = percentSupportedByPrior)
+}
+
+summaryTable <- do.call(rbind, summaryList)
+print(summaryTable)
+
+write.table(as.data.frame(summaryTable), file.path(dirOut, "summaryStatistics.tsv"), col.names = NA, row.names = TRUE, quote = FALSE, sep = "\t")
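The percent-supported-by-prior computation in the loop above reduces to a string-key intersection. A self-contained toy example (the TF/gene names are invented for illustration):

```r
# Hypothetical network edges and prior entries, keyed as "TF_Target"
netPairs   <- c("Foxp3_Il2ra", "Tbx21_Ifng", "Batf_Il21")
priorPairs <- c("Foxp3_Il2ra", "Batf_Il21", "Stat3_Il17a")

supportedCount   <- length(intersect(netPairs, priorPairs))
percentSupported <- 100 * supportedCount / length(netPairs)
percentSupported  # 66.66667 (2 of 3 edges appear in the prior)
```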
+
+
+# ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
+# ---- PART TWO: Compare Networks (can handle more than 2)
+# ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
+if (compareNets){
+
+ #--- Build edge sets
+ edgeSets <- lapply(netDataList, function(df) paste(df[[tfCol]], df[[targetCol]], sep="~"))
+ names(edgeSets) <- names(netDataList)
+ networkNames <- names(netDataList)
+ numNetworks <- length(netDataList)
+
+ #--- Edges shared between all networks
+ allShared <- Reduce(intersect, edgeSets)
+
+ #1. --- Pairwise Jaccard, Intersections, and Heatmap
+
+ if (numNetworks > 1) {
+ pairInt <- matrix(NA, numNetworks, numNetworks, dimnames=list(networkNames,networkNames))
+ pairJac <- matrix(NA, numNetworks, numNetworks, dimnames=list(networkNames,networkNames))
+
+ for (i in 1:(numNetworks-1)) {
+ for (j in (i+1):numNetworks) {
+ shared <- intersect(edgeSets[[i]], edgeSets[[j]])
+ un <- union(edgeSets[[i]], edgeSets[[j]])
+ pairInt[i,j] <- length(shared)
+ pairJac[i,j] <- pairJac[j,i] <- if (length(un)) length(shared)/length(un) else NA
+ }
+ }
+ diag(pairInt) <- vapply(edgeSets, length, integer(1))
+ diag(pairJac) <- 1
+
+ # heat-map of Jaccard ----------------------------------------
+ heatCols <- colorRampPalette(RColorBrewer::brewer.pal(9,"YlOrRd"))(100)
+ pdf(file.path(dirOut, "Jaccard.pdf"))
+ pheatmap(pairJac, display_numbers = TRUE, main = "Pairwise Jaccard Index", color = heatCols)
+ dev.off()
+
+ # -------- unique-edge counts ----------------------------------------- ##
+ uniqCountsDf <- data.frame(
+ network = names(edgeSets),
+ numUniqueEdges = sapply(seq_along(edgeSets), function(k){
+ curr <- edgeSets[[k]]
+ others <- unique(unlist(edgeSets[-k]))
+ sum(!curr %in% others)
+ }),
+ row.names = NULL,
+ stringsAsFactors = FALSE
+ )
+
+ # 3. ------- Save Pairwise and uniqueDF counts
+ pairDf <- melt(pairInt, varnames = c("network1", "network2"), value.name = "numIntersection", na.rm = TRUE)
+ write.table(pairDf, file.path(dirOut, "num_Pairwise_Intersection.tsv"), quote = F, col.names = NA, row.names=TRUE, sep = "\t")
+
+ }
+
+ # 2.------ Venn or UpSet Plot
+  if (numNetworks >= 2 && numNetworks <= 3) {
+ vennInput <- edgeSets
+ ggVennDiagram(vennInput, label_alpha = 0.6) + scale_fill_gradient(low="grey90", high = "red")
+ ggsave(file.path(dirOut, "Venn2.pdf"))
+ # venn.plot <- venn.diagram(
+ # vennInput,
+ # category.names = names(vennInput),
+ # filename = NULL,
+ # output = TRUE,
+ # main = "Shared TF~Gene Pairs"
+ # )
+ # pdf(file.path(dirOut, "Venn2.pdf"))
+ # grid::grid.draw(venn.plot)
+ # dev.off()
+
+ } else if (numNetworks > 3) {
+ # allEdges <- unique(unlist(edgeSets))
+ # incMat <- sapply(edgeSets, function(edgeSet) as.integer(allEdges %in% edgeSet))
+ # colnames(incMat) <- names(edgeSets)
+ # rownames(incMat) <- allEdges
+ upSetData <- edgeSets
+
+ plotUpSet <- function(setsData, Mode){
+ mObj <- make_comb_mat(setsData, mode = Mode)
+ ht <- UpSet(mObj, set_order = names(setsData),
+ top_annotation = upset_top_annotation(mObj, annotation_name_rot = 90, axis = FALSE,
+ add_numbers = TRUE, numbers_rot = 90,
+ gp = gpar(col = comb_degree(mObj), fontsize = 6), height = unit(4, "cm"),
+ axis_param = list(side = "left")),
+ right_annotation = upset_right_annotation(mObj, axis_param = list(side = "bottom"), #labels = FALSE,labels_rot = 0
+ gp = gpar(fill = "black", fontsize = 6), width = unit(4, "cm"), show_annotation_name = FALSE,
+ add_numbers = TRUE, axis = FALSE
+ )
+ )
+ return(ht)
+ }
+
+ pdf(file.path(dirOut, "UpSet_distinct.pdf"), width = 8.5, height = 4)
+ htDistinct <- plotUpSet(upSetData, Mode = "distinct")
+ draw(htDistinct)
+ dev.off()
+
+ pdf(file.path(dirOut, "UpSet_Intersect.pdf"), width = 8.5, height = 4)
+ htIntersect <- plotUpSet(upSetData, Mode = "intersect")
+ draw(htIntersect)
+ dev.off()
+ }
+}
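The `Reduce(intersect, edgeSets)` idiom near the top of this block folds pairwise intersection across any number of networks, yielding the edges present in every one. A quick illustration with dummy edge keys:

```r
# Dummy "TF~Gene" edge keys for three hypothetical networks
edgeSets <- list(
  NetA = c("Tf1~G1", "Tf1~G2", "Tf2~G3"),
  NetB = c("Tf1~G2", "Tf2~G3", "Tf3~G4"),
  NetC = c("Tf2~G3", "Tf5~G9")
)
allShared <- Reduce(intersect, edgeSets)  # edges present in every network
allShared  # "Tf2~G3"
```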
+
+
+# ────────────────────────────────────────────────────────────────────────────────────────────
+# PART THREE: Histogram distributions of:
+ # A: Targets per TF: # times each gene is targeted (number of TFs regulating it).
+ # B: TFs per Target: # times each TF is a regulator (number of genes it controls)
+ # C: Box-Plot of Stability of the top N low and high in degree genes
+
+ #NOTE: Each Figure is saved in the same path as the network being evaluated
+# ────────────────────────────────────────────────────────────────────────────────────────────
+# Generate both plots for each network
+for (typeName in names(netDataList)) {
+ data <- netDataList[[typeName]]
+ basePath <- dirname(netFiles[[typeName]])
+
+ # Plot 1: # Targets per TF
+ tfTargetCounts <- table(data[[tfCol]])
+ dfTF <- data.frame(TF = names(tfTargetCounts), targetCount = as.integer(tfTargetCounts))
+
+ p1 <- ggplot(dfTF, aes(x = targetCount)) +
+ geom_histogram(binwidth = 1, boundary = 0.5, fill = "#0072B2", color = "black", alpha = 0.8) +
+ # scale_x_continuous(breaks = seq(0, max(dfTF$targetCount), 1)) +
+ labs(title = paste("Distribution of # Targets per TF -", typeName),
+ x = "# Targets", y = "# TFs") +
+ theme_bw(base_family = "Helvetica") +
+ theme(
+ panel.grid.major.y = element_line(color = "grey80", linewidth = 0.3),
+ panel.grid.minor = element_blank(),
+ # axis.line = element_line(color = "black", linewidth = 0.4),
+ axis.text.x = element_text(size = 7),
+ axis.text.y = element_text(size = 7),
+ axis.title = element_text(size = 9),
+ plot.margin = margin(5, 5, 5, 5),
+ # panel.background = element_rect("black", fill = NA)
+ )
+
+ ggsave(file.path(basePath, paste0("TargetCountPerTF_", typeName, ".pdf")),
+ plot = p1, width = 3, height = 3)
+
+ # Plot 2: # TFs per Target
+ targetTFCounts <- table(data[[targetCol]])
+ dfTarget<- data.frame(Target = names(targetTFCounts), tfCount = as.integer(targetTFCounts))
+ dfTargetSorted <- dfTarget[order(dfTarget$tfCount), ]
+
+ p2 <- ggplot(dfTarget, aes(x = tfCount)) +
+ geom_histogram(binwidth = 1, boundary = 0.5, fill = "blue", color = "black", alpha = 0.7) +
+ # scale_x_continuous(breaks = seq(0, max(dfTarget$tfCount), 1)) +
+ labs(title = paste("Distribution of # TFs per Target -", typeName),
+ x = "# TFs", y = "# Targets") +
+ theme_bw(base_family = "Helvetica") +
+ theme(
+ panel.grid.major.y = element_line(color = "grey80", linewidth = 0.3),
+ panel.grid.minor = element_blank(),
+ # axis.line = element_line(color = "black", linewidth = 0.4),
+ axis.text.x = element_text(size = 7),
+ axis.text.y = element_text(size = 7),
+ axis.title = element_text(size = 9),
+ plot.margin = margin(5, 5, 5, 5),
+ # panel.background = element_rect("black", fill = NA)
+ )
+
+ ggsave(file.path(basePath, paste0("TFCountPerTarget_", typeName, ".pdf")),
+ plot = p2, width = 3, height = 3)
+
+ # Select top N low and high TF targets
+ lowTargs <- dfTargetSorted$Target[1:nSelect]
+ highTargs <- tail(dfTargetSorted, nSelect)$Target
+
+ stabilityDF <- data[data[[targetCol]] %in% c(lowTargs, highTargs), c(targetCol, stabCol)]
+ # Add Group column based on whether the target is in lowTargs or highTargs
+ stabilityDF$Group <- ifelse(stabilityDF[[targetCol]] %in% lowTargs,
+ "Low TFs per Target",
+ "High TFs per Target")
+ stabilityDF <- merge(stabilityDF, dfTargetSorted, by.x = targetCol, by.y = "Target")
+ stabilityDF$targetLabel <- paste0(stabilityDF[[targetCol]], "(", stabilityDF$tfCount, ")")
+ # Keep plotting order
+# stabilityDF$targetLabel <- factor(stabilityDF$targetLabel,
+# levels = unique(stabilityDF$targetLabel))
+ # Order based on tfCount directly
+ stabilityDF$targetLabel <- factor(
+ stabilityDF$targetLabel,
+ levels = unique(stabilityDF$targetLabel[order(stabilityDF$tfCount)])
+ )
+
+ # Plot: boxplot per target
+ p3 <- ggplot(stabilityDF, aes(x = targetLabel, y = Stability, fill = Group)) +
+ geom_boxplot(outlier.size = 0.5, alpha = 0.8) +
+ labs(title = paste("Per-Target Stability Distribution -", typeName),
+ x = "Number of TFs per Target (TF count)", y = "Stability") +
+ theme_minimal() +
+ theme_bw(base_family = "Helvetica") +
+ theme(
+ panel.grid.major.y = element_line(color = "grey80", linewidth = 0.3),
+ panel.grid.minor = element_blank(),
+ # axis.line = element_line(color = "black", linewidth = 0.4),
+ axis.text.x = element_text(size = 7, angle = 90, vjust = 0.5),
+ axis.text.y = element_text(size = 7),
+ axis.title = element_text(size = 9),
+ plot.margin = margin(5, 5, 5, 5),
+ legend.position = "top"
+ # panel.background = element_rect("black", fill = NA)
+ )
+
+ # Save
+ ggsave(file.path(basePath, paste0("Top", nSelect, "HighorLow_inDegreeGenes_Boxplot", typeName, ".pdf")),
+ plot = p3, width = 12, height = 5)
+
+}
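The x-axis ordering of the stability boxplot above relies on setting factor levels from a sort of `tfCount`. That trick in isolation (labels are made up):

```r
df <- data.frame(targetLabel = c("GeneA(5)", "GeneB(2)", "GeneC(9)"),
                 tfCount = c(5, 2, 9))

# Levels ordered by ascending tfCount, so ggplot draws boxes left-to-right by degree
df$targetLabel <- factor(df$targetLabel,
                         levels = unique(df$targetLabel[order(df$tfCount)]))
levels(df$targetLabel)  # "GeneB(2)" "GeneA(5)" "GeneC(9)"
```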
+
diff --git a/evaluation/R/evaluateNetworks1.R b/evaluation/R/evaluateNetworks1.R
new file mode 100755
index 0000000..f25e504
--- /dev/null
+++ b/evaluation/R/evaluateNetworks1.R
@@ -0,0 +1,708 @@
+rm(list = ls())
+suppressPackageStartupMessages({
+ library(dplyr)
+ library(tidyr)
+ library(ggplot2)
+ library(ComplexHeatmap)
+ library(pheatmap)
+ library(circlize)
+ # library(VennDiagram)
+# library(ggVennDiagram)
+ library(grid)
+ library(reshape2)
+ library(RColorBrewer)
+ library(purrr)
+
+})
+
+source("/data/miraldiNB/Michael/Scripts/VennDiagram.R")
+source("/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/RprogUtils/evaluateNetUtils.R")
+
+# dirOut <- "/data/miraldiNB/Michael/hCD4T_Katko/Inferelator/noMergedTF/Bulk/5kbTSS/newPipe/spikeIgnored/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63"
+dirOut <- "/data/miraldiNB/Michael/mCD4T_Wayman/Figures/Fig4/networkComparison1"
+dir.create(dirOut, recursive = T, showWarnings = F)
+
+# Combined/combined_max.tsv
+# TFmRNA/edges_subset.txt
+# netFiles <- list(
+# "TFA" = "/data/miraldiNB/Michael/hCD4T_Katko/Inferelator/noMergedTF/SC/SCENIC/geneLambda0p5_220totSS_20tfsPerGene_subsamplePCT5_eRegulon/TFA/edges_subset.tsv",
+# "TFmRNA" = "/data/miraldiNB/Michael/hCD4T_Katko/Inferelator/noMergedTF/SC/SCENIC/geneLambda0p5_220totSS_20tfsPerGene_subsamplePCT5_eRegulon/TFmRNA/edges_subset.tsv",
+# "Combined" = "/data/miraldiNB/Michael/hCD4T_Katko/Inferelator/noMergedTF/SC/SCENIC/geneLambda0p5_220totSS_20tfsPerGene_subsamplePCT5_eRegulon/Combined/combined_max.tsv"
+# )
+
+# netFiles <- list(
+# "TFA" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.tsv",
+# "TFmRNA" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/TFmRNA/edges_subset.tsv",
+# "Combined" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/Combined/combined_max.tsv"
+# )
+
+netFiles <- list(
+ "PB" = "/data/miraldiNB/Michael/projects/GRN/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/Combined/combined_max.tsv",
+ "SC" = "/data/miraldiNB/Michael/projects/GRN/mCD4T_Wayman/Inferelator/noMergedTF/SC/ATAC/geneLambda0p5_220totSS_20tfsPerGene_subsamplePCT5/Combined/combined_max.tsv"
+ # "MC2" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/metaCells/MC2/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63_logNorm/Combined/combined_max.tsv",
+ # "SEA" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/metaCells/SEACells/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/Combined/combined_max.tsv"
+)
+
+# priorFile = "/data/miraldiNB/Michael/hCD4T_Katko/dataBank/Priors/MotifScan5kbTSS_b_sp.tsv"
+priorFile = "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10_sp.tsv"
+# k <- read.table("/data/miraldiNB/Michael/hCD4T_Katko/dataBank/Priors/SCENICp/SCENICp_RE2Glinks_FIMO_5Kb_derived_b.tsv")
+# k$Target <- rownames(k)
+# b <- melt(k, id.vars = "Target", variable.names = "TF", value.name = "Weigths")
+# colnames(b)[2:3] <- c("TF", "Weights")
+# b <- b[, c("TF", "Target", "Weights")]
+# bb <- b[b$Weights != 0, ]
+# write.table(bb, "/data/miraldiNB/Michael/hCD4T_Katko/dataBank/Priors/SCENICp/SCENICp_RE2Glinks_FIMO_5Kb_derived_sp.tsv", row.names = F, sep ="\t", quote = F)
+
+tfCol <- "TF" # specify TF column name here
+targetCol <- "Gene" # specify Target column name here
+priorTfCol <- "TF" # TF column name in prior file
+priorTargetCol <- "Target" # Target column name in prior file
+priorWgtCol <-"Weight"
+nSelect <- 10
+stabCol <- "Stability"
+rankCol <- "signedQuantile"
+compareNets <- TRUE
+tfList <- NULL
+lineageTFs <- c("Bcl6","Maf","Batf","Tox","Ascl2","Foxp3","Ikzf2","Rorc","Stat3",
+                "Tbx21","Stat4","Eomes","Prdm1","Tcf7","Klf2","Lef1")
+
+# lineageTFs <- c("TCF7","LEF1","ID3","KLF2","PRDM1","ZEB2","RUNX3","EOMES",
+# "TBX21","STAT1","STAT4","HIF1A","RORC","RORA","STAT3","IRF4",
+# "BATF","FOXP3","IKZF2","IKZF4","STAT5","FOXO1","GATA3","STAT6",
+# "CIITA","RFX5","RFXAP","RFXANK","NLRC5")
+N <- NULL
+# file_k_clust_across <- "/Users/owop7y/Desktop/MiraldiLab/PROJECTS/GRN/Benchmark/mCD4T/Figures/Fig4/networkComparison/Kmeans_Jaccard_Edges_perTF.rds"
+file_k_clust_across <- NULL
+
+# ========================================================================
+# ------- Read Input files
+# ========================================================================
+
+# ----- Read Prior
+priorData <- read.table(priorFile, header = TRUE, sep = "\t", stringsAsFactors = FALSE)
+head(priorData, 3)
+priorData <- priorData[priorData[[priorWgtCol]] != 0, ]
+priorPairs <- unique(paste(priorData[[priorTfCol]], priorData[[priorTargetCol]], sep = "_"))
+
+# --- Preload all network data into a list ---
+netDataList <- list()
+for (typeName in names(netFiles)) {
+ filePath <- netFiles[[typeName]]
+ netDataList[[typeName]] <- read.table(filePath, header = TRUE, sep = "\t", stringsAsFactors = FALSE)
+}
+
+numNetworks <- length(netDataList)
+networkNames <- names(netDataList)
+
+# ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
+# ---- PART ONE: Make a table of number of network ppts - uniqueTFs, uniqueTargets, Total Interactions, and % supported by prior
+# ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
+# process each network
+summaryList <- list()
+for (typeName in networkNames) {
+ data <- netDataList[[typeName]]
+
+ uniqueTf <- length(unique(data[[tfCol]]))
+ uniqueTarget <- length(unique(data[[targetCol]]))
+
+ pairStrings <- unique(paste(data[[tfCol]], data[[targetCol]], sep = "_"))
+ totalInteractions <- length(pairStrings)
+
+ supportedCount <- length(intersect(pairStrings, priorPairs))
+ percentSupportedByPrior <- 100 * supportedCount / totalInteractions
+
+ summaryList[[typeName]] <- data.frame(type = typeName, uniqueTf = uniqueTf,
+ uniqueTarget = uniqueTarget, totalInteractions = totalInteractions,
+ percentSupportedByPrior = percentSupportedByPrior)
+}
+
+summaryTable <- do.call(rbind, summaryList)
+print(summaryTable)
+
+write.table(as.data.frame(summaryTable), file.path(dirOut, "summaryStatistics.tsv"), col.names = NA, row.names = TRUE, quote = FALSE, sep = "\t")
+
+# ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
+# ---- PART TWO: Compare Networks (can handle more than 2)
+# ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
+
+# ──── Compare TF-Targets across methods
+# Initialize a data frame to store results
+tfTargetCounts <- data.frame(TF = unique(unlist(lapply(netDataList, function(x) unique(x$TF)))))
+
+# Count targets for each TF in each network
+for(netName in names(netDataList)) {
+ netData <- netDataList[[netName]]
+
+ # Count number of unique target genes per TF
+ targetCounts <- netData %>%
+ group_by(TF) %>%
+ summarise(targetCount = n())
+
+ colname <- netName
+ tfTargetCounts <- tfTargetCounts %>%
+ left_join(targetCounts, by = "TF") %>%
+ rename(!!colname := targetCount)
+}
+# Replace NA with 0 for TFs not present in a network
+tfTargetCounts[is.na(tfTargetCounts)] <- 0
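The `rename(!!colname := targetCount)` call in the loop above uses rlang injection to name a column from a variable. A minimal sketch of the same pattern (requires dplyr; the data are invented):

```r
library(dplyr)

counts  <- data.frame(TF = c("Foxp3", "Tbx21"), targetCount = c(12L, 7L))
netName <- "PB"  # column name held in a variable

# `!!netName :=` splices the string in as the new column name
renamed <- counts %>% rename(!!netName := targetCount)
names(renamed)  # "TF" "PB"
```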
+
+# Build file suffix
+fileSuffix <- if(!is.null(N)) {   # N (Top-N setting) is defined above; NULL means all edges
+  paste0(rankCol, "_top", N, "pct")
+} else {
+  paste0(rankCol, "_all")
+}
+networkNames <- names(netDataList)
+numNetworks <- length(netDataList)
+
+edgeSets <- getEdgeSets(netDataList, tfCol, targetCol, rankCol)
+
+# -------- unique-edge counts ----------------------------------------- ##
+uniqCountsDf <- data.frame(
+ network = names(edgeSets),
+ numUniqueEdges = sapply(seq_along(edgeSets), function(k){
+ curr <- edgeSets[[k]]
+ others <- unique(unlist(edgeSets[-k]))
+ sum(!curr %in% others)
+ }),
+ row.names = NULL,
+ stringsAsFactors = FALSE
+ )
+
+write.table(uniqCountsDf,
+ file.path(dirOut, paste0("num_UniqueEdges_", fileSuffix, ".tsv")),
+ quote=F, col.names=NA, row.names=TRUE, sep="\t")
+
+# ------- Compute pairwise intersections/Jaccard and plot the global heatmap
+pairs <- computeJaccardMatrix(edgeSets)  # returns pairJac (and pairInt) for the full edge sets
+plotSimilarityHeatmap(pairs[["pairJac"]], fileOut=file.path(dirOut, paste0("Global_Jaccard_", fileSuffix, ".pdf")), fontsize_number=7, fontsize=9, fig_width=2.5, fig_height=2.5)
+pairDf <- melt(pairs[["pairInt"]], varnames = c("network1", "network2"), value.name = "numIntersection", na.rm = TRUE)
+write.table(pairDf,file.path(dirOut, paste0("num_Pairwise_Intersection_", fileSuffix, ".tsv")), quote=F, col.names=NA, row.names=TRUE, sep="\t")
+
+
+# --------- Compute Jaccard for multiple topN% -------------
+topNpercentList <- seq(10, 100, by = 10)
+jaccardResults <- list()
+for (topN in topNpercentList) {
+ currEdgeSets <- getEdgeSets(netDataList, tfCol, targetCol, rankCol, "topNpercent", topN)
+ # Compute pairwise Jaccard
+ fileSuffix <- paste0(rankCol, "_top", topN, "pct")
+ pairs <- computeJaccardMatrix(currEdgeSets)
+ pairJac <- pairs[["pairJac"]]
+ # Convert pairwise matrix to long format
+ pairJacDf <- as.data.frame(as.table(pairJac))
+ colnames(pairJacDf) <- c("Network1", "Network2", "Jaccard")
+ pairJacDf$TopNpercent <- topN
+ jaccardResults[[as.character(topN)]] <- pairJacDf
+}
+# Combine all percentages
+jaccardDf <- do.call(rbind, jaccardResults)
+# Filter out self-comparisons
+jaccardDf <- jaccardDf[jaccardDf$Network1 != jaccardDf$Network2, ]
+jaccardDf<- jaccardDf %>%
+ dplyr::mutate(
+ Network1 = as.character(Network1),
+ Network2 = as.character(Network2),
+ pairID = ifelse(Network1 < Network2,
+ paste(Network1, Network2, sep = "-"),
+ paste(Network2, Network1, sep = "-"))
+ ) %>%
+ dplyr::group_by(TopNpercent, pairID) %>%
+ dplyr::slice(1) %>%
+ dplyr::ungroup()
+write.table(jaccardDf,file.path(dirOut, "Jaccard_Sim.tsv"), quote=F, col.names=NA, row.names=TRUE, sep="\t")
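The `pairID` construction above collapses the symmetric comparisons "A vs B" and "B vs A" onto one key by alphabetizing the two names before pasting; `slice(1)` then keeps a single row per (TopNpercent, pairID). The canonicalization in isolation:

```r
n1 <- c("SC", "PB")  # Network1 column of two symmetric rows
n2 <- c("PB", "SC")  # Network2 column

pairID <- ifelse(n1 < n2,
                 paste(n1, n2, sep = "-"),
                 paste(n2, n1, sep = "-"))
pairID  # both rows collapse to "PB-SC"
```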
+
+# Create color mapping based on Set1 palette
+uniquePairs <- sort(unique(jaccardDf$pairID))
+colors <- c("#449B75" ,"#FFE528", "#999999", "#E41A1C", "#AC5782", "#C66764")
+comparisonColor = setNames(colors, uniquePairs)
+
+# Plot line graph with ggplot2
+p_jac <- ggplot(jaccardDf, aes(x = TopNpercent, y = Jaccard,color = pairID,group = pairID)) +
+ geom_line(linewidth = 0.5) +
+ geom_point(size = 1) +
+ scale_color_manual(values = comparisonColor) +
+ scale_x_continuous(breaks = seq(10, 100, by = 10)) +
+ scale_y_continuous(limits = c(0,1)) +
+ labs(
+ x = "Top N% edges",
+ y = "Jaccard similarity",
+ color = "Network pair"
+ ) +
+ theme_bw(base_family = "Helvetica") +
+ theme(
+ axis.text.x = element_text(size = 7, color = "black"),
+ axis.text.y = element_text(size = 7, color = "black"),
+ axis.title = element_text(size = 9, color = "black"),
+ legend.title = element_blank(),
+ legend.text = element_text(size = 9, color = "black"),
+ legend.position = "right",
+    legend.box.just = "left", # left-justify the legend box
+ panel.grid.major = element_line(color = "grey80", linewidth = 0.3),
+ panel.grid.minor = element_line(color = "grey90", linewidth = 0.1)
+ )
+ggsave(file.path(dirOut, "linePlot_JaccardSim.pdf"),plot = p_jac, width = 4, height = 3, dpi = 600)
+
+#plotEdgeSharing(edgeSets, fileSuffix, outDir)
+
+# --- 3. TF-centric analysis
+# For each TF, how consistent are its predicted targets across networks?
+# Biological networks are TF-driven, so target consistency is checked per TF.
+tfMatrix <- computeTFJaccard(netDataList, tfList)
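+# computeTFJaccard is defined elsewhere in this codebase; conceptually, for each
+# TF it compares that TF's target set between each pair of networks. A minimal
+# Python sketch of the per-TF comparison (hypothetical gene names):

```python
# Jaccard similarity of one TF's target sets in two networks:
# |intersection| / |union|, defined as 0.0 when both sets are empty.
def jaccard(a, b):
    a, b = set(a), set(b)
    return len(a & b) / len(a | b) if (a or b) else 0.0

targets_net1 = {"Il2", "Ifng", "Tbx21"}         # hypothetical targets, network 1
targets_net2 = {"Il2", "Ifng", "Gata3", "Il4"}  # hypothetical targets, network 2

sim = jaccard(targets_net1, targets_net2)  # 2 shared / 5 in union = 0.4
```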
+# Make heatmap
+getPalette <- colorRampPalette(brewer.pal(9, "Set1")) # Set1 provides at most 9 colors
+comparisonColor <- getPalette(length(colnames(tfMatrix)))
+comparisonColor <- setNames(comparisonColor, colnames(tfMatrix))
+colAnn <- HeatmapAnnotation(
+ `Comparison` = colnames(tfMatrix),
+ col = list('Comparison' = comparisonColor),
+ annotation_legend_param = list(
+ Comparison = list(
+ title_gp = gpar(fontsize = 8, fontface = "plain"), # title size
+ labels_gp = gpar(fontsize = 6) # category label size
+ )),
+ simple_anno_size = unit(3, "mm"),
+ show_annotation_name = FALSE
+ )
+
+
+tfMatrix[is.na(tfMatrix)] <- 0
+tfRobustness <- rowMeans(tfMatrix) # average Jaccard per TF (NAs already set to 0)
+
+k_center <- 6
+if (is.null(file_k_clust_across)) {
+  k_clust_across <- kmeans(tfMatrix, centers = k_center, nstart = 20, iter.max = 50)
+} else {
+  k_clust_across <- readRDS(file_k_clust_across)
+}
+split_by <- factor(k_clust_across$cluster, levels = 1:k_center)
+
+# Add row markers (labels with lines)
+row_mark <- rowAnnotation(
+  mark = anno_mark(
+    at = which(rownames(tfMatrix) %in% lineageTFs),
+    labels = rownames(tfMatrix)[rownames(tfMatrix) %in% lineageTFs], # keep labels aligned with `at`
+    labels_gp = gpar(fontsize = 6, fontface = "plain"),
+    padding = unit(0.3, "mm"),
+    side = "left" # line starts from left of heatmap
+  )
+)
+
+annot_cols <- c('1'='#fb940b','2'='#01cc00','3'='#2085ec','4'='#fe98bf','5'='#ad7a5b', '6'='turquoise4')
+df_clust <- data.frame(across = k_clust_across$cluster)
+df_clust$across <- factor(df_clust$across, levels = 1:k_center, ordered = TRUE)
+rowAnn_left <- HeatmapAnnotation(`Cluster` = df_clust$across,
+                                 col = list('Cluster' = annot_cols),
+                                 which = 'row',
+                                 simple_anno_size = unit(2, 'mm'),
+                                 show_legend = FALSE,
+                                 show_annotation_name = FALSE)
+
+# robustCols <- colorRamp2(c(min(tfRobustness), max(tfRobustness)), c("lightgrey","blue"))
+# robustCols <- colorRamp2( seq(from = min(tfRobustness), to = max(tfRobustness), length.out = 200),inferno(200))
+robustCols <- colorRamp2(seq(min(tfRobustness), max(tfRobustness), length.out = 100), colorRampPalette(brewer.pal(9, "Blues"))(100))
+rowRobustness <- rowAnnotation(
+  Robustness = anno_simple(
+    tfRobustness,
+    col = robustCols,
+    border = TRUE
+  ),
+  show_annotation_name = FALSE,
+  width = unit(1.5, "mm")
+)
+
+# Create a separate legend for robustness
+robustLegend <- Legend(
+ title = "TF Robustness",
+ col_fun = robustCols,
+ title_gp = gpar(fontsize = 8, fontface = "plain"), # unbold
+ labels_gp = gpar(fontsize = 6) # optional: label size
+)
+
+# dirOut <- "/data/MiraldiLab/team/Michael/mCD4T_WaymanFigures/Fig4/networkComparison/"
+# dir.create(dirOut, recursive = T, showWarnings = F)
+
+heatCols <- colorRampPalette(RColorBrewer::brewer.pal(9,"YlOrRd"))(100)
+pdf(file.path(dirOut, "Jaccard_Edges_perTF1.pdf"), width = 2.3, height = 4.5, compress = T)
+ht1 <- Heatmap(tfMatrix,
+ name='Jaccard Similarity',
+ show_row_names = F,
+ row_split=split_by,
+ row_gap = unit(0, "mm"),
+ column_gap = unit(0, "mm"),
+ border = TRUE,
+ row_title = NULL,
+ row_dend_reorder = F,
+ use_raster = F,
+ col = heatCols,
+ # clustering settings
+               cluster_rows = TRUE, # hierarchical clustering within each row slice
+               cluster_row_slices = TRUE, # also cluster the k-means slices themselves
+ cluster_columns = FALSE,
+ cluster_column_slices = F,
+ # column_order = colnames(tfa_z_across),
+               show_row_dend = FALSE, # hide row dendrograms
+ show_column_dend = FALSE,
+ show_column_names = FALSE,
+ column_names_side = 'top',
+ # column_names_rot = 45,
+ column_title = NULL,
+ column_split = colnames(tfMatrix),
+ top_annotation = colAnn,
+ left_annotation = rowAnn_left,
+ heatmap_legend_param = list(
+ direction = "horizontal",
+ title = "Jaccard",
+ title_position = "topcenter",
+ title_gp = gpar(fontsize = 8, fontface = "plain"), # legend title size
+ labels_gp = gpar(fontsize = 6), # numbers/tick labels size
+ legend_width = unit(2.5, "cm"),
+ legend_height = unit(0.5, "cm")
+ )
+ )
+draw(row_mark + ht1 + rowRobustness, heatmap_legend_side = "bottom")
+dev.off()
+
+ht2 <- draw(ht1) # draw again (to the default device) so row_order() reflects the final layout
+row_order_list <- row_order(ht2)
+saveRDS(row_order_list, file.path(dirOut, "final_row_order.rds") )
+# Save Legend as PDF
+pdf(file.path(dirOut, "TF_Robustness_Legend.pdf"), width = 2, height = 2)
+draw(robustLegend)
+dev.off()
+saveRDS(k_clust_across, file.path(dirOut, "Kmeans_Jaccard_Edges_perTF.rds"))
+
+
+# ---------- Check
+clust_df <- df_clust
+clust_df$TF <- rownames(clust_df)
+colnames(clust_df)[1] <- "cluster"
+tfTargetCounts <- tfTargetCounts %>%
+ left_join(clust_df, by = "TF")
+
+# Calculate average targets across all methods
+tfTargetCounts$avgTargets <- rowMeans(tfTargetCounts[, names(netDataList)], na.rm = TRUE)
+tfTargetCounts$medianTargets <- apply(tfTargetCounts[, names(netDataList)], 1, median, na.rm = TRUE)
+# # Summary statistics by cluster
+# clusterSummary <- tfTargetCounts %>%
+# group_by(cluster) %>%
+# summarise(
+# meanTargets = mean(avgTargets),
+# medianTargets = median(avgTargets),
+# sdTargets = sd(avgTargets),
+# nTFs = n()
+# )
+
+write.table(tfTargetCounts, file.path(dirOut, "TF_avgTargetCounts_byRobustnessCluster.tsv"), quote=F, col.names=NA, row.names=TRUE, sep="\t")
+
+pAvg <- ggplot(tfTargetCounts, aes(x = cluster, y = avgTargets, fill = cluster)) +
+ geom_boxplot(outlier.color = "red", outlier.size = 1,outlier.shape = 19, outlier.alpha = 0.7) +
+ geom_jitter(width = 0.2, size = 0.1) +
+ labs(title = "",
+ y = "Average # of Targets",
+ x = "") +
+ scale_fill_manual(values = annot_cols) +
+  theme_bw(base_family = "Helvetica") +
+ theme(
+ axis.text.y = element_text(size = 7, color = "black"),
+ axis.text.x = element_blank(),
+ axis.ticks.x = element_blank() ,
+ axis.title = element_text(size = 9, color = "black"),
+ legend.position = "none",
+ panel.grid.major = element_line(color = "grey80", linewidth = 0.3),
+ panel.grid.minor = element_line(color = "grey90", linewidth = 0.1)
+ )
+ggsave(file.path(dirOut, "TF_avgTargetCounts_Distribution.pdf"), pAvg, width = 2.5, height = 2, dpi = 600)
+
+
+pMedian <- ggplot(tfTargetCounts, aes(x = cluster, y = medianTargets, fill = cluster)) +
+ geom_boxplot(outlier.color = "red", outlier.size = 1,outlier.shape = 19, outlier.alpha = 0.7) +
+ geom_jitter(width = 0.2, size = 0.1) +
+ labs(title = "",
+ y = "Median # of Targets",
+ x = "") +
+ scale_fill_manual(values = annot_cols) +
+  theme_bw(base_family = "Helvetica") +
+ theme(
+ axis.text.y = element_text(size = 7, color = "black"),
+ axis.text.x = element_blank(),
+ axis.ticks.x = element_blank() ,
+ axis.title = element_text(size = 9, color = "black"),
+ legend.position = "none",
+ panel.grid.major = element_line(color = "grey80", linewidth = 0.3),
+ panel.grid.minor = element_line(color = "grey90", linewidth = 0.1)
+ )
+ggsave(file.path(dirOut, "TF_medianTargetCounts_Distribution.pdf"), pMedian, width = 2.5, height = 2, dpi = 600)
+
+
+# tf_stats <- lapply(names(netDataList), function(nm){
+# df <- netDataList[[nm]]
+# df %>%
+# count(!!sym(tfCol)) %>%
+# rename(TF = !!sym(tfCol), !!nm := n)
+# }) %>%
+# reduce(full_join, by = "TF") %>%
+# mutate(cluster = k_clust_across$cluster[TF])
+#
+# # Which cluster have higher mean degree and which has lower mean degree
+# cluster_summary <- tf_stats %>%
+# group_by(cluster) %>%
+# summarise(across(starts_with("PB") | starts_with("SC") | starts_with("MC") | starts_with("SEA"),
+# list(mean = ~mean(.x, na.rm=TRUE),
+# sd = ~sd(.x, na.rm=TRUE))),
+# n_TFs = n())
+#
+# ## Only Lineage TFs
+# tfMatrix_lineage <- tfMatrix[intersect(lineageTFs, rownames(tfMatrix)), ]
+# k_clust_lineage<- kmeans(tfMatrix_lineage, centers=4,nstart=20, iter.max = 50)
+# split_by_lineage <- factor(k_clust_lineage$cluster, levels = c(1:4))
+#
+# pdf(file.path(dirOut, "Jaccard_Edges_perTF_lineage.pdf"), width = 2, height = 3, compress = T)
+# ht <- Heatmap(tfMatrix_lineage,
+# name = 'Jaccard Similarity',
+# show_row_names = TRUE,
+# row_names_gp = gpar(fontsize = 6),
+# row_names_side = "left", # rownames on the left
+# row_split = split_by_lineage,
+# row_gap = unit(0, "mm"),
+# column_gap = unit(0, "mm"),
+# border = TRUE,
+# row_title = NULL,
+# row_dend_reorder = FALSE,
+# use_raster = FALSE,
+# col = heatCols,
+# cluster_rows = TRUE,
+# cluster_row_slices = TRUE,
+# cluster_columns = FALSE,
+# cluster_column_slices = FALSE,
+# show_row_dend = FALSE,
+# show_column_dend = FALSE,
+# show_column_names = FALSE,
+# column_names_side = 'top',
+# column_title = NULL,
+# column_split = colnames(tfMatrix_lineage),
+# top_annotation = colAnn,
+# heatmap_legend_param = list(
+# direction = "horizontal",
+# title = "Jaccard Similarity",
+# title_position = "topcenter",
+# title_gp = gpar(fontsize = 8, fontface = "plain"), # legend title size
+# labels_gp = gpar(fontsize = 6), # numbers/tick labels size
+# legend_width = unit(2.5, "cm"),
+# legend_height = unit(0.5, "cm")
+# )
+# )
+# draw(ht, heatmap_legend_side = "bottom")
+# dev.off()
+# saveRDS(k_clust_lineage, file.path(dirOut, "Kmeans_Jaccard_Edges_perTF_lineage.rds"))
+
+# --- 4. Aggregated network-pair Jaccard
+fun <- median
+fun_name <- "median"
+aggMat <- computeAggregatedPairJaccard(tfMatrix, fun = fun)
+plotAggregatedJaccard(aggMat,file.path(dirOut, paste0("AggregatedTF_Jaccard_", fun_name, ".pdf")), fig_width = 2.5, fig_height = 2.5)
+
+fun <- mean
+fun_name <- "mean"
+aggMat <- computeAggregatedPairJaccard(tfMatrix, fun = fun)
+plotAggregatedJaccard(aggMat,file.path(dirOut, paste0("AggregatedTF_Jaccard_", fun_name, ".pdf")), fig_width = 2.5, fig_height = 2.5)
+
+
+# ----- Are top regulators (high-degree TFs) stable across pseudobulk / single-cell / metacell networks?
+hubSim <- computeHubOverlapHeatmap(
+ netDataList = netDataList, tfCol = tfCol, networkNames = networkNames,
+ dirOut = dirOut, topN = 50,
+ metric = "jaccard", # or "overlap"
+ fontsize = 9, fontsize_number = 7, fig_width = 2.5, fig_height = 2.5
+)
+
+hubSim <- computeHubOverlapHeatmap(
+ netDataList = netDataList, tfCol = tfCol, networkNames = networkNames,
+ dirOut = dirOut, topN = 50,
+ metric = "overlap", # or "jaccard"
+ fontsize = 9, fontsize_number = 7, fig_width = 2.5, fig_height = 2.5
+)
+
+# ────────────────────────────────────────────────────────────────────────────────────────────
+# PART THREE: Histogram distributions of:
+  # A: Targets per TF: number of genes each TF regulates (TF out-degree)
+  # B: TFs per Target: number of TFs regulating each gene (gene in-degree)
+  # C: Box plot of stability for the top N low and high in-degree genes
+
+  # NOTE: Each figure is saved in the same path as the network being evaluated
+# ────────────────────────────────────────────────────────────────────────────────────────────
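+# The two degree distributions plotted below are simple counts over the edge
+# list; a Python sketch of both directions (hypothetical edges):

```python
# Out-degree (# targets per TF) and in-degree (# TFs per target)
# computed from a TF -> target edge list.
from collections import Counter

edges = [  # hypothetical (TF, target) edges
    ("Tbx21", "Ifng"),
    ("Tbx21", "Il12rb2"),
    ("Stat4", "Ifng"),
    ("Gata3", "Il4"),
]

targets_per_tf = Counter(tf for tf, _ in edges)      # out-degree per TF
tfs_per_target = Counter(gene for _, gene in edges)  # in-degree per gene
```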
+# Generate both plots for each network
+for (typeName in names(netDataList)) {
+ data <- netDataList[[typeName]]
+ basePath <- dirname(netFiles[[typeName]])
+
+ # Plot 1: # Targets per TF
+ # tfTargetCounts <- table(data[[tfCol]])
+ # dfTF <- data.frame(TF = names(tfTargetCounts), targetCount = as.integer(tfTargetCounts))
+  dfTF <- tfTargetCounts[, c("TF", typeName), drop = FALSE] # per-network count column is named by typeName
+  colnames(dfTF)[2] <- "targetCount"
+
+ p1 <- ggplot(dfTF, aes(x = targetCount)) +
+ geom_histogram(binwidth = 1, boundary = 0.5, fill = "#0072B2", color = "black", alpha = 0.8) +
+ # scale_x_continuous(breaks = seq(0, max(dfTF$targetCount), 1)) +
+ labs(title = paste("Distribution of # Targets per TF -", typeName),
+ x = "# Targets", y = "# TFs") +
+ theme_bw(base_family = "Helvetica") +
+ theme(
+ panel.grid.major.y = element_line(color = "grey80", linewidth = 0.3),
+ panel.grid.minor = element_blank(),
+ # axis.line = element_line(color = "black", linewidth = 0.4),
+ axis.text.x = element_text(size = 7),
+ axis.text.y = element_text(size = 7),
+ axis.title = element_text(size = 9),
+ plot.margin = margin(5, 5, 5, 5),
+ # panel.background = element_rect("black", fill = NA)
+ )
+
+ ggsave(file.path(basePath, paste0("TargetCountPerTF_", typeName, ".pdf")),
+ plot = p1, width = 3, height = 3)
+
+ # Plot 2: # TFs per Target
+ targetTFCounts <- table(data[[targetCol]])
+ dfTarget<- data.frame(Target = names(targetTFCounts), tfCount = as.integer(targetTFCounts))
+ dfTargetSorted <- dfTarget[order(dfTarget$tfCount), ]
+
+ p2 <- ggplot(dfTarget, aes(x = tfCount)) +
+ geom_histogram(binwidth = 1, boundary = 0.5, fill = "blue", color = "black", alpha = 0.7) +
+ # scale_x_continuous(breaks = seq(0, max(dfTarget$tfCount), 1)) +
+ labs(title = paste("Distribution of # TFs per Target -", typeName),
+ x = "# TFs", y = "# Targets") +
+ theme_bw(base_family = "Helvetica") +
+ theme(
+ panel.grid.major.y = element_line(color = "grey80", linewidth = 0.3),
+ panel.grid.minor = element_blank(),
+ # axis.line = element_line(color = "black", linewidth = 0.4),
+ axis.text.x = element_text(size = 7),
+ axis.text.y = element_text(size = 7),
+ axis.title = element_text(size = 9),
+ plot.margin = margin(5, 5, 5, 5),
+ # panel.background = element_rect("black", fill = NA)
+ )
+
+ ggsave(file.path(basePath, paste0("TFCountPerTarget_", typeName, ".pdf")),
+ plot = p2, width = 3, height = 3)
+
+ # Select top N low and high TF targets
+  lowTargs <- head(dfTargetSorted$Target, nSelect)
+  highTargs <- tail(dfTargetSorted$Target, nSelect)
+
+ stabilityDF <- data[data[[targetCol]] %in% c(lowTargs, highTargs), c(targetCol, stabCol)]
+ # Add Group column based on whether the target is in lowTargs or highTargs
+ stabilityDF$Group <- ifelse(stabilityDF[[targetCol]] %in% lowTargs,
+ "Low TFs per Target",
+ "High TFs per Target")
+ stabilityDF <- merge(stabilityDF, dfTargetSorted, by.x = targetCol, by.y = "Target")
+ stabilityDF$targetLabel <- paste0(stabilityDF[[targetCol]], "(", stabilityDF$tfCount, ")")
+ # Keep plotting order
+# stabilityDF$targetLabel <- factor(stabilityDF$targetLabel,
+# levels = unique(stabilityDF$targetLabel))
+ # Order based on tfCount directly
+ stabilityDF$targetLabel <- factor(
+ stabilityDF$targetLabel,
+ levels = unique(stabilityDF$targetLabel[order(stabilityDF$tfCount)])
+ )
+
+ # Plot: boxplot per target
+  p3 <- ggplot(stabilityDF, aes(x = targetLabel, y = .data[[stabCol]], fill = Group)) +
+ geom_boxplot(outlier.size = 0.5, alpha = 0.8) +
+ labs(title = paste("Per-Target Stability Distribution -", typeName),
+ x = "Number of TFs per Target (TF count)", y = "Stability") +
+    theme_bw(base_family = "Helvetica") +
+ theme(
+ panel.grid.major.y = element_line(color = "grey80", linewidth = 0.3),
+ panel.grid.minor = element_blank(),
+ # axis.line = element_line(color = "black", linewidth = 0.4),
+ axis.text.x = element_text(size = 7, angle = 90, vjust = 0.5),
+ axis.text.y = element_text(size = 7),
+ axis.title = element_text(size = 9),
+ plot.margin = margin(5, 5, 5, 5),
+ legend.position = "top"
+ # panel.background = element_rect("black", fill = NA)
+ )
+
+ # Save
+ ggsave(file.path(basePath, paste0("Top", nSelect, "HighorLow_inDegreeGenes_Boxplot", typeName, ".pdf")),
+ plot = p3, width = 12, height = 5)
+
+}
+
+# library(reshape2)
+# library(ggplot2)
+#
+# # Convert to long format for ggplot
+# df_long <- melt(tfa_by_celltype)
+# colnames(df_long) <- c("TF", "CellType", "TFA")
+#
+# # Add a flag for lineage TFs
+# df_long$Lineage <- mapply(function(tf, ct) {
+# if(tf %in% lineage_TFs[[ct]]) "Lineage" else "Other"
+# }, df_long$TF, df_long$CellType)
+#
+# # Boxplot or violin plot
+# ggplot(df_long, aes(x = CellType, y = TFA, fill = Lineage)) +
+# geom_violin(alpha = 0.6) +
+# geom_boxplot(width=0.1, outlier.shape=NA) +
+# scale_fill_manual(values = c("Lineage"="red", "Other"="grey")) +
+# theme_minimal() +
+# theme(axis.text.x = element_text(angle=45, hjust=1)) +
+# labs(title="Lineage TF enrichment in TFA", y="Mean TFA")
+#
+#
+# lineage_TFs <- list(
+# Tfh10 = c("Bcl6","Maf","Batf"),
+# Tfh_Int = c("Bcl6","Maf","Batf"),
+# Tfh = c("Bcl6","Maf"),
+# Tfr = c("Foxp3","Bcl6"),
+# cTreg = c("Foxp3","Ikzf2"),
+# eTreg = c("Foxp3","Ikzf2"),
+# rTreg = c("Foxp3","Ikzf2"),
+# Treg_Rorc = c("Foxp3","Rorc"),
+# Th17 = c("Rorc","Batf","Stat3"),
+# Th1 = c("Tbx21","Stat4"),
+# CTL_Prdm1 = c("Tbx21","Eomes","Prdm1"),
+# CTL_Bcl6 = c("Tbx21","Eomes","Prdm1"),
+# TEM = c("Eomes","Tcf7","Klf2"),
+# TCM = c("Eomes","Tcf7","Klf2")
+# )
+#
+
+# lineage_tfs_list <- list(
+# TCM = c("TCF7", "LEF1", "ID3", "KLF2"),
+# TEM = c("PRDM1", "ZEB2", "RUNX3", "EOMES"),
+# Th1 = c("TBX21", "STAT1", "STAT4", "RUNX3", "HIF1A"),
+# Th17 = c("RORC", "RORA", "STAT3", "IRF4", "BATF"),
+# Treg = c("FOXP3", "IKZF2", "IKZF4", "STAT5"),
+# Naive = c("TCF7", "LEF1", "KLF2", "FOXO1"),
+# Th2 = c("GATA3", "STAT5", "STAT6", "IRF4"),
+# CTL = c("EOMES", "TBX21", "RUNX3", "ZEB2", "PRDM1"),
+# MHCII = c("CIITA", "RFX5", "RFXAP", "RFXANK", "NLRC5")
+# )
+
+#
+#
+# # Distribution per F1.
+# # Also show F1 score based on heatmap
+# ggplot(pr_df, aes(x = network, y = F1)) +
+# geom_boxplot(outlier.size = 0.8) +
+# geom_jitter(width = 0.15, alpha = 0.6, size = 1) +
+# theme_minimal() +
+# theme(axis.text.x = element_text(angle=45, hjust=1)) +
+# ylab("F1 score") + xlab("Representation")
diff --git a/evaluation/R/histogramConfidences.R b/evaluation/R/histogramConfidences.R
new file mode 100755
index 0000000..c233cf0
--- /dev/null
+++ b/evaluation/R/histogramConfidences.R
@@ -0,0 +1,194 @@
+# Plot weights/confidences
+rm(list=ls())
+library(ggplot2)
+
+# Custom function to format y-axis labels
+custom_labels <- function(x) {
+ # Keep 0 as is, scale other values by 1000 and add "K"
+ labels <- ifelse(x == 0, "0", paste0(scales::number(x / 1000, accuracy = 0.1), "K"))
+ return(labels)
+}
+
+histogramConfidencesDir <- function(currNetDirs, breaks) {
+ # Ensure breaks length matches currNetDirs length
+ if (length(currNetDirs) != length(breaks)) {
+ stop("Length of breaks must match length of currNetDirs.")
+ }
+
+ for (ix in seq_along(currNetDirs)) {
+ currNetDir <- currNetDirs[ix]
+ subfolders <- list.dirs(currNetDir, recursive = FALSE)
+ target_folders <- subfolders[basename(subfolders) %in% c("TFA", "TFmRNA")]
+
+ # Load and plot individually
+ for (subfolder in target_folders) {
+ # subfolder = target_folders[jx]
+ file_path <- file.path(subfolder, "edges_subset.txt")
+ if (file.exists(file_path)) {
+ data <- read.table(file_path, header = TRUE, sep = "\t")
+
+ # Create individual histogram
+ conf <- data.frame(confidence = floor(data$Stability) / max(floor(data$Stability)))
+      p <- ggplot(conf, aes(x = confidence)) +
+ geom_histogram(bins = breaks[ix], fill = "blue", color = "black", alpha = 0.9) +
+ labs(title = basename(subfolder), x = "Confidence", y = "Frequency") +
+ theme_bw() +
+ theme(
+ axis.title = element_text(size = 16, color = "black"),
+ axis.text = element_text(size = 14, color = "black"),
+          plot.title = element_text(size = 20, color = "black", hjust = 0.5)
+        ) + scale_y_continuous(labels = custom_labels)
+
+ # Save individual plot
+ hist_file <- file.path(subfolder, paste0("confidence_distribution_", breaks[ix], ".png"))
+ ggsave(hist_file, plot = p, width = 6, height = 5, dpi = 600)
+ } else {
+ message("File not found: ", file_path)
+ }
+ }
+
+ }
+
+}
+
+
+histogramConfidencesData <- function(currNetFiles, breaks) {
+ # Ensure breaks length matches currNetFiles length
+ if (length(currNetFiles) > length(breaks)) {
+    stop("breaks must be at least as long as currNetFiles.")
+ }
+
+  # Loop over the network files in currNetFiles
+ for (ix in seq_along(currNetFiles)) {
+ currNetFile <- currNetFiles[ix] # Full path to edges_subset.txt
+
+ # Check if the file exists
+ if (file.exists(currNetFile)) {
+ data <- read.table(currNetFile, header = TRUE, sep = "\t")
+
+ # Create individual histogram
+ conf <- data.frame(confidence = floor(data$Stability) / max(floor(data$Stability)))
+ p <- ggplot(conf, aes(x = confidence)) +
+ geom_histogram(bins = breaks[ix], fill = "blue", color = "black", alpha = 0.9) +
+ labs(title = basename(dirname(currNetFile)), x = "Confidence", y = "Frequency") +
+ theme_bw() +
+ theme(
+ axis.title = element_text(size = 16, color = "black"),
+ axis.text = element_text(size = 14, color = "black"),
+ plot.title = element_text(size = 20, color = "black", hjust = 0.5)
+ ) + scale_y_continuous(labels = custom_labels) # Custom y-axis labels
+
+ # Save individual plot in the same directory
+ hist_file <- file.path(dirname(currNetFile), paste0("confidence_distribution_", breaks[ix], ".png"))
+ ggsave(hist_file, plot = p, width = 6, height = 5, dpi = 600)
+ } else {
+ message("File not found: ", currNetFile)
+ }
+ }
+}
+
+currNetFile <- "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/1KCells/lambda0p5_220totSS_20tfsPerGene_subsamplePCT20/TFmRNA/edges_subset.txt"
+
+data <- read.table(currNetFile, header = TRUE, sep = "\t") # load the network before using data$Stability
+conf <- data.frame(confidence = floor(data$Stability))
+breaks <- max(conf$confidence)
+p <- ggplot(conf, aes(x = confidence)) +
+ geom_histogram(bins = breaks, fill = "blue", color = "black", alpha = 0.9) +
+ labs(title = basename(dirname(currNetFile)), x = "Confidence", y = "Frequency") +
+ theme_bw() +
+ theme(
+ axis.title = element_text(size = 16, color = "black"),
+ axis.text = element_text(size = 14, color = "black"),
+ plot.title = element_text(size = 20, color = "black", hjust = 0.5)
+ )
+
+hist_file <- file.path(dirname(currNetFile), paste0("testConf_", breaks, ".png"))
+ggsave(hist_file, plot = p, width = 6, height = 5, dpi = 600)
+
+# USAGE
+
+
+currNetDirs <- c(
+ # ---Pseudobulk Inferelator
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda0p25_80totSS_20tfsPerGene_subsamplePCT63",
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63",
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda1p0_80totSS_20tfsPerGene_subsamplePCT63",
+
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/Inferelator/ATAC_ChIPprior/Bulk/lambda0p25_80totSS_20tfsPerGene_subsamplePCT63",
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63",
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/Bulk/lambda1p0_80totSS_20tfsPerGene_subsamplePCT63",
+
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/Inferelator/ATAC_KOprior/Bulk/lambda0p25_80totSS_20tfsPerGene_subsamplePCT63",
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63",
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/Bulk/lambda1p0_80totSS_20tfsPerGene_subsamplePCT63",
+
+ # --- single cell inferelator
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/lambda0p25_220totSS_20tfsPerGene_subsamplePCT10",
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/lambda0p5_220totSS_20tfsPerGene_subsamplePCT10",
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/lambda1p0_220totSS_20tfsPerGene_subsamplePCT10",
+
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/SC/lambda0p25_220totSS_20tfsPerGene_subsamplePCT10",
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/SC/lambda0p5_220totSS_20tfsPerGene_subsamplePCT10",
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/SC/lambda1p0_220totSS_20tfsPerGene_subsamplePCT10",
+
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/SC/lambda0p25_220totSS_20tfsPerGene_subsamplePCT10",
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/SC/lambda0p5_220totSS_20tfsPerGene_subsamplePCT10",
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/SC/lambda1p0_220totSS_20tfsPerGene_subsamplePCT10",
+
+ # --- single cell downsampled
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/1KCells/lambda0p5_220totSS_20tfsPerGene_subsamplePCT27",
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/10KCells/lambda0p5_220totSS_20tfsPerGene_subsamplePCT27",
+ "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/30KCells/lambda0p5_220totSS_20tfsPerGene_subsamplePCT13"
+)
+
+
+# Extract subsample fractions
+# subsample_fracs <- sapply(currNetDirs, function(x) sub(".*subsampleFrac_([0-9p]+).*", "\\1", x))
+# subsample_fracs <- gsub("p", ".", subsample_fracs) # Convert 'p' to '.' for numeric comparison
+# unique_fracs <- unique(subsample_fracs)
+
+breaks = c(rep(150, 9), rep(500, 12))
+
+
+
+netFiles = c(
+ "1K" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/1KCells/lambda0p5_220totSS_20tfsPerGene_subsamplePCT27/TFA/edges_subset.txt",
+ "10K" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/10KCells/lambda0p5_220totSS_20tfsPerGene_subsamplePCT27/TFA/edges_subset.txt",
+ "30K" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/30KCells/lambda0p5_220totSS_20tfsPerGene_subsamplePCT13/TFA/edges_subset.txt",
+ "77K" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/lambda0p5_220totSS_20tfsPerGene_subsamplePCT10/TFA/edges_subset.txt"
+ )
+
+histogramConfidencesStacked <- function(currNetFiles, breaks, dirOut, saveName) {
+
+ allData <- data.frame(confidence = numeric(), network = character())
+
+ for (ix in seq_along(currNetFiles)) {
+ currNetFile <- currNetFiles[ix] # Full path to edges_subset.txt
+ currNetName <- names(currNetFile)
+ if (file.exists(currNetFile)) {
+ data <- read.table(currNetFile, header = TRUE, sep = "\t")
+      # Scale floored stability to [0, 1] and tag rows with the network name
+      conf <- data.frame(confidence = floor(data$Stability) / max(floor(data$Stability)))
+      tempDf <- data.frame(confidence = conf$confidence,
+                           network = currNetName)
+ allData <- rbind(allData, tempDf)
+ } else {
+ message("File not found: ", currNetFile)
+ }
+ }
+
+  p <- ggplot(allData, aes(x = confidence, fill = network)) +
+    # geom_freqpoly(bins = breaks[1], size = 1) +
+    geom_histogram(bins = breaks[1], alpha = 0.6, position = "identity", color = "black") +
+ labs(x = "Confidence", y = "Frequency") +
+ theme_bw() +
+ theme(
+ axis.title = element_text(size = 16, color = "black"),
+ axis.text = element_text(size = 14, color = "black")
+ )
+ histFile <- file.path(dirOut, paste0(saveName, ".pdf"))
+ ggsave(histFile, plot = p, width = 6, height = 5, dpi = 600)
+
+}
+
+
+histogramConfidencesStacked(netFiles, breaks = 300, dirOut = "/data/miraldiNB/Michael/mCD4T_Wayman/Figures", saveName = "confidencesTest")
\ No newline at end of file
diff --git a/evaluation/R/saveNormCountsArrowFIle.R b/evaluation/R/saveNormCountsArrowFIle.R
new file mode 100755
index 0000000..ddf647a
--- /dev/null
+++ b/evaluation/R/saveNormCountsArrowFIle.R
@@ -0,0 +1,35 @@
+rm(list=ls())
+options(stringsAsFactors=FALSE)
+set.seed(42)
+suppressPackageStartupMessages({
+ library(Signac)
+ library(Seurat)
+ library(arrow)
+})
+
+# make sure you load or have the latest version of curl installed.
+
+rds_path <- "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/annotation_scrna_final/obj_Tfh10_RNA_annotated.rds"
+obj <- readRDS(rds_path)
+
+# Process File for scGRN using Inferelator.
+obj <- NormalizeData(obj)
+norm_counts <- as.matrix(GetAssayData(obj, layer='data'))
+# write_parquet(as.data.frame(norm_counts), "/data/miraldiNB/Michael/GRN_Benchmark/Data/Tfh10_scRNA_logNorm_Counts.parquet")
+
+norm_counts <- as.data.frame(norm_counts)
+norm_counts$Genes <- rownames(norm_counts)
+norm_counts <- norm_counts[, c("Genes", setdiff(colnames(norm_counts), "Genes"))] # Reorder the columns to make 'Genes' the first column
+
+# Save as an arrow file
+write_feather(norm_counts, "/data/miraldiNB/Michael/GRN_Benchmark/Data/Tfh10_scRNA_logNorm_Counts.arrow")
+
+# write_feather(norm_counts, "/data/miraldiNB/Michael/hCD4T_Katko/dataBank/scNormCounts.arrow")
+
+# feather_file <- read_feather("/data/miraldiNB/Katko/Projects/Barski_CD4_Multiome/Outs/Pseudobulk/RNA2/SC_counts.feather")
+# write_ipc_file(feather_file, "/data/miraldiNB/Katko/Projects/Barski_CD4_Multiome/Outs/Pseudobulk/RNA2/SC_counts.arrow")
+
+
+rds_path <- "/data/miraldiNB/Michael/hCD4T_Katko/dataBank/ObjFiltered.rds"
\ No newline at end of file
diff --git a/examples/interactive_pipeline.jl b/examples/interactive_pipeline.jl
new file mode 100644
index 0000000..03fcfe0
--- /dev/null
+++ b/examples/interactive_pipeline.jl
@@ -0,0 +1,172 @@
+# =============================================================================
+# interactive_pipeline.jl — Step-by-step GRN inference using the public API
+#
+# What:
+# Runs the full 6-step mLASSO-StARS pipeline interactively. Each step calls
+# one high-level API function. Intermediate structs remain in the REPL between
+# steps so you can inspect, plot, or debug before continuing.
+#
+# Required inputs:
+# geneExprFile — gene expression matrix (.txt tab-delimited or .arrow)
+# targFile — target gene list (.txt, one gene per line)
+# regFile — potential regulator (TF) list (.txt, one TF per line)
+# priorFile — prior network matrix (.tsv, sparse TF × gene)
+# priorFilePenalties — prior(s) used to set LASSO penalties (same or different)
+#
+# Expected outputs (written under outputDir/<networkBaseName>/):
+# TFA/edges.tsv — GRN inferred with TFA predictors
+# TFmRNA/edges.tsv — GRN inferred with TF mRNA predictors
+# Combined/combined_*.tsv — consensus network (max/mean/min aggregation)
+# Combined/TFA/ — refined TFA network re-estimated from consensus prior
+#
+# Usage:
+# Step-by-step in a Julia REPL (recommended for interactive analysis)
+# julia examples/interactive_pipeline.jl
+#
+# Compare with:
+# interactive_pipeline_dev.jl — same steps, module-qualified internal calls
+# run_pipeline.jl — same pipeline wrapped in a callable function
+#
+# Installation:
+# pkg> dev /path/to/InferelatorJL # local development
+# pkg> add "https://github.com/org/InferelatorJL.jl" # published release
+# Tip: load Revise before this file to pick up source edits without restarting:
+# using Revise; using InferelatorJL
+# =============================================================================
+using Revise
+using InferelatorJL
+
+# =============================================================================
+# Configuration — edit these paths and parameters for your dataset
+# =============================================================================
+
+outputDir = "/data/miraldiNB/Michael/projects/GRN/mCD4T_Wayman/Inferelator/test3"
+tfaOptions = ["", "TFmRNA"] # "" → TFA mode, "TFmRNA" → mRNA mode
+totSS = 80
+bstarsTotSS = 5
+subsampleFrac = 0.68
+minLambda = 0.01
+maxLambda = 0.5
+totLambdasBstars = 20
+totLambdas = 40
+targetInstability = 0.05
+meanEdgesPerGene = 20
+correlationWeight = 1
+minTargets = 3
+edgeSS = 0
+lambdaBias = [0.5]
+instabilityLevel = "Network" # "Network" or "Gene"
+useMeanEdgesPerGeneMode = true
+combineOpt = "max" # "max", "mean", or "min"
+zScoreTFA = true # z-score targets before TFA estimation
+zScoreLASSO = true # z-score targets before LASSO regression
+
+geneExprFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/pseudobulk/pseudobulk_scrna/CellType/Age/Factor1/min0.25M/counts_Tfh10_AgeCellType_pseudobulk_scrna_vst_batch_downsample_0.25M.txt"
+targFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/target_genes/gene_targ_Tfh10_SigPct5Log2FC0p58FDR5.txt"
+regFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/pot_regs/TF_Tfh10_SigPct5Log2FC0p58FDR5_final.txt"
+
+priorFile = "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv"
+priorFilePenalties = ["/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv"]
+tfaGeneFile = "" # optional: restrict TFA estimation to a gene subset
+
+# --- Build output directory name (encodes key run parameters)
+subsamplePct = round(subsampleFrac * 100, digits = 4)  # round to avoid float artifacts (e.g. 0.29 * 100 == 28.999999999999996)
+subsampleStr = isinteger(subsamplePct) ? string(Int(subsamplePct)) : replace(string(subsamplePct), "." => "p")
+lambdaStr = join(replace.(string.(lambdaBias), "." => "p"), "_")
+networkBaseName = lowercase(instabilityLevel) * "Lambda" * lambdaStr * "_" * string(totSS) * "totSS_" *
+ string(meanEdgesPerGene) * "tfsPerGene_" * "subsamplePCT" * subsampleStr
+dirOut = joinpath(outputDir, networkBaseName)
+mkpath(dirOut)
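+# With the defaults above (instabilityLevel = "Network", lambdaBias = [0.5], totSS = 80,
+# meanEdgesPerGene = 20, subsampleFrac = 0.68), the encoded name works out to:
+#   dirOut == joinpath(outputDir, "networkLambda0p5_80totSS_20tfsPerGene_subsamplePCT68")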
+
+@info "Configuration" outputDir=dirOut geneExprFile priorFile lambdaBias subsampleFrac
+
+# =============================================================================
+# STEP 1 — Load and filter expression data
+# =============================================================================
+# Loads expression matrix, filters to target genes and potential regulators.
+# Inspect: fieldnames(GeneExpressionData), size(data.expressionMat)
+
+data = loadData(geneExprFile, targFile, regFile;
+ tfaGeneFile = tfaGeneFile,
+ epsilon = 0.01)
+
+# =============================================================================
+# STEP 2 + 3 — Merge degenerate TFs, process prior, estimate TFA
+# =============================================================================
+# Merges TFs with identical binding profiles, builds the prior matrix,
+# and estimates TF activity (TFA) via least-squares.
+# Inspect: size(priorData.medTfas), priorData.tfNames
+
+priorData, mergedTFs = loadPrior(data, priorFile; minTargets = minTargets)
+
+estimateTFA(priorData, data;
+ edgeSS = edgeSS,
+ zScoreTFA = zScoreTFA,
+ outputDir = dirOut)
+
+# =============================================================================
+# STEP 4 — Build GRN for each predictor mode
+# =============================================================================
+# Runs mLASSO-StARS for TFA predictors and TF mRNA predictors separately.
+# Outputs instability curves and ranked edge lists to TFA/ and TFmRNA/.
+
+for (tfaMode, modeLabel) in [(true, "TFA"), (false, "TFmRNA")]
+ instabilitiesDir = joinpath(dirOut, modeLabel)
+ mkpath(instabilitiesDir)
+
+ @info "Building network" mode=modeLabel
+
+ buildNetwork(data, priorData;
+ tfaMode = tfaMode,
+ priorFilePenalties = priorFilePenalties,
+ lambdaBias = lambdaBias,
+ totSS = totSS,
+ bstarsTotSS = bstarsTotSS,
+ subsampleFrac = subsampleFrac,
+ minLambda = minLambda,
+ maxLambda = maxLambda,
+ totLambdasBstars = totLambdasBstars,
+ totLambdas = totLambdas,
+ targetInstability = targetInstability,
+ meanEdgesPerGene = meanEdgesPerGene,
+ correlationWeight = correlationWeight,
+ instabilityLevel = instabilityLevel,
+ useMeanEdgesPerGeneMode = useMeanEdgesPerGeneMode,
+ zScoreLASSO = zScoreLASSO,
+ outputDir = instabilitiesDir)
+end
+
+# =============================================================================
+# STEP 5 — Aggregate TFA + mRNA networks into a consensus network
+# =============================================================================
+# Combines the two edge lists using max/mean/min stability per (TF, gene) pair.
+# Outputs combined_<combineOpt>.tsv and combined_<combineOpt>_sp.tsv to Combined/.
+
+combinedNetDir = joinpath(dirOut, "Combined")
+aggregateNetworks(
+ [joinpath(dirOut, "TFA", "edges.tsv"),
+ joinpath(dirOut, "TFmRNA", "edges.tsv")];
+ method = Symbol(combineOpt),
+ meanEdgesPerGene = meanEdgesPerGene,
+ useMeanEdgesPerGene = useMeanEdgesPerGeneMode,
+ outputDir = combinedNetDir)
+
+# =============================================================================
+# STEP 6 — Re-estimate TFA using the consensus network as a refined prior
+# =============================================================================
+# Uses the combined network as a new prior to re-estimate TF activity, then
+# re-runs mLASSO-StARS. Outputs go to Combined/TFA/.
+
+netsCombinedSparse = joinpath(combinedNetDir, "combined_" * combineOpt * "_sp.tsv")
+refineTFA(netsCombinedSparse, data, mergedTFs;
+ tfaGeneFile = tfaGeneFile,
+ edgeSS = edgeSS,
+ minTargets = minTargets,
+ zScoreTFA = zScoreTFA,
+ exprFile = geneExprFile,
+ targFile = targFile,
+ regFile = regFile,
+ outputDir = combinedNetDir)
+
+@info "Pipeline complete" outputDir=dirOut
diff --git a/examples/interactive_pipeline_dev.jl b/examples/interactive_pipeline_dev.jl
new file mode 100644
index 0000000..f3e6fc7
--- /dev/null
+++ b/examples/interactive_pipeline_dev.jl
@@ -0,0 +1,199 @@
+# =============================================================================
+# interactive_pipeline_dev.jl — Step-by-step GRN inference via internal functions
+#
+# What:
+# Identical pipeline to interactive_pipeline.jl but calls internal functions
+# directly using the InferelatorJL. module prefix. Use this when you need
+# finer control over individual steps (e.g., inspecting intermediate matrices,
+# swapping in a custom subfunction, or debugging a specific stage).
+# All internal functions remain accessible via module-qualified calls even
+# though they are not exported from the package.
+#
+# Required inputs:
+# geneExprFile — gene expression matrix (.txt tab-delimited or .arrow)
+# targFile — target gene list (.txt, one gene per line)
+# regFile — potential regulator (TF) list (.txt, one TF per line)
+# priorFile — prior network matrix (.tsv, sparse TF × gene)
+# priorFilePenalties — prior(s) used to set LASSO penalties (same or different)
+#
+# Expected outputs (written under outputDir/<networkBaseName>/):
+# TFA/edges.tsv — GRN inferred with TFA predictors
+# TFmRNA/edges.tsv — GRN inferred with TF mRNA predictors
+# Combined/combined_*.tsv — consensus network (max/mean/min aggregation)
+# Combined/TFA/ — refined TFA network re-estimated from consensus prior
+#
+# Usage:
+#   Run step-by-step in a Julia REPL (recommended for inspecting intermediate state), or
+#   non-interactively with: julia examples/interactive_pipeline_dev.jl
+#
+# Compare with:
+# interactive_pipeline.jl — same steps via high-level public API
+# run_pipeline_dev.jl — same internal calls wrapped in a callable function
+#
+# Installation:
+# pkg> dev /path/to/InferelatorJL # local development
+# pkg> add "https://github.com/org/InferelatorJL.jl" # published release
+# Tip: load Revise before this file to pick up source edits without restarting:
+# using Revise; using InferelatorJL
+# =============================================================================
+using Revise
+using InferelatorJL
+
+# =============================================================================
+# Configuration — edit these paths and parameters for your dataset
+# =============================================================================
+
+outputDir = "/data/miraldiNB/Michael/projects/GRN/mCD4T_Wayman/Inferelator/test"
+tfaOptions = ["", "TFmRNA"] # "" → TFA mode, "TFmRNA" → mRNA mode
+totSS = 80
+bstarsTotSS = 5
+subsampleFrac = 0.68
+minLambda = 0.01
+maxLambda = 0.5
+totLambdasBstars = 20
+totLambdas = 40
+targetInstability = 0.05
+meanEdgesPerGene = 20
+correlationWeight = 1
+minTargets = 3
+edgeSS = 0
+lambdaBias = [0.5]
+instabilityLevel = "Network" # "Network" or "Gene"
+useMeanEdgesPerGeneMode = true
+combineOpt = "max" # "max", "mean", or "min"
+zScoreTFA = true # z-score targets before TFA estimation
+zScoreLASSO = true # z-score targets before LASSO regression
+
+geneExprFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/pseudobulk/pseudobulk_scrna/CellType/Age/Factor1/min0.25M/counts_Tfh10_AgeCellType_pseudobulk_scrna_vst_batch_downsample_0.25M.txt"
+targFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/target_genes/gene_targ_Tfh10_SigPct5Log2FC0p58FDR5.txt"
+regFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/pot_regs/TF_Tfh10_SigPct5Log2FC0p58FDR5_final.txt"
+
+priorFile = "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv"
+priorFilePenalties = ["/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv"]
+tfaGeneFile = "" # optional: restrict TFA estimation to a gene subset
+
+# --- Build output directory name (encodes key run parameters)
+subsamplePct = round(subsampleFrac * 100, digits = 4)  # round to avoid float artifacts (e.g. 0.29 * 100 == 28.999999999999996)
+subsampleStr = isinteger(subsamplePct) ? string(Int(subsamplePct)) : replace(string(subsamplePct), "." => "p")
+lambdaStr = join(replace.(string.(lambdaBias), "." => "p"), "_")
+networkBaseName = lowercase(instabilityLevel) * "Lambda" * lambdaStr * "_" * string(totSS) * "totSS_" *
+ string(meanEdgesPerGene) * "tfsPerGene_" * "subsamplePCT" * subsampleStr
+dirOut = joinpath(outputDir, networkBaseName)
+mkpath(dirOut)
+
+@info "Configuration" outputDir=dirOut geneExprFile priorFile lambdaBias subsampleFrac
+
+# =============================================================================
+# STEP 1 — Load and filter expression data
+# =============================================================================
+# Loads expression matrix, filters to target genes and potential regulators.
+# Inspect: fieldnames(GeneExpressionData), size(data.expressionMat), data.geneNames
+
+data = GeneExpressionData()
+InferelatorJL.loadExpressionData!(data, geneExprFile)
+InferelatorJL.loadAndFilterTargetGenes!(data, targFile; epsilon = 0.01)
+InferelatorJL.loadPotentialRegulators!(data, regFile)
+InferelatorJL.processTFAGenes!(data, tfaGeneFile; outputDir = dirOut)
+
+# =============================================================================
+# STEP 2 — Merge degenerate TFs
+# =============================================================================
+# Identifies TFs with identical binding profiles in the prior and merges them.
+# Inspect: mergedTFsData.mergedTFs, mergedTFsData.tfNames
+
+mergedTFsData = mergedTFsResult()
+InferelatorJL.mergeDegenerateTFs(mergedTFsData, priorFile; fileFormat = 2)
+
+# =============================================================================
+# STEP 3 — Process prior and estimate TFA
+# =============================================================================
+# Builds the filtered prior matrix and estimates TF activity via least-squares.
+# Inspect: tfaData.priorMat, tfaData.medTfas, size(tfaData.medTfas)
+
+tfaData = PriorTFAData()
+InferelatorJL.processPriorFile!(tfaData, data, priorFile;
+ mergedTFsData = mergedTFsData,
+ minTargets = minTargets)
+InferelatorJL.calculateTFA!(tfaData, data;
+ edgeSS = edgeSS,
+ zTarget = zScoreTFA,
+ outputDir = dirOut)
+
+# =============================================================================
+# STEP 4 — Build GRN for each predictor mode
+# =============================================================================
+# Runs subsampling, warm-start lambda selection, instability estimation,
+# lambda selection, and edge ranking for each predictor mode.
+
+for tfaOpt in tfaOptions
+ instabilitiesDir = tfaOpt == "" ? joinpath(dirOut, "TFA") : joinpath(dirOut, "TFmRNA")
+ mkpath(instabilitiesDir)
+
+ @info "Building network" tfaOpt=(isempty(tfaOpt) ? "TFA" : tfaOpt)
+
+ grnData = GrnData()
+ InferelatorJL.preparePredictorMat!(grnData, data, tfaData; tfaOpt = tfaOpt)
+ InferelatorJL.preparePenaltyMatrix!(data, grnData;
+ priorFilePenalties = priorFilePenalties,
+ lambdaBias = lambdaBias,
+ tfaOpt = tfaOpt)
+ InferelatorJL.constructSubsamples(data, grnData; totSS = bstarsTotSS, subsampleFrac = subsampleFrac)
+ InferelatorJL.bstarsWarmStart(data, tfaData, grnData;
+ minLambda = minLambda,
+ maxLambda = maxLambda,
+ totLambdasBstars = totLambdasBstars,
+ targetInstability = targetInstability,
+ zTarget = zScoreLASSO)
+ InferelatorJL.constructSubsamples(data, grnData; totSS = totSS, subsampleFrac = subsampleFrac)
+ InferelatorJL.bstartsEstimateInstability(grnData;
+ totLambdas = totLambdas,
+ instabilityLevel = instabilityLevel,
+ zTarget = zScoreLASSO,
+ outputDir = instabilitiesDir)
+
+ buildGrn = BuildGrn()
+ InferelatorJL.chooseLambda!(grnData, buildGrn;
+ instabilityLevel = instabilityLevel,
+ targetInstability = targetInstability)
+ InferelatorJL.rankEdges!(data, tfaData, grnData, buildGrn;
+ useMeanEdgesPerGeneMode = useMeanEdgesPerGeneMode,
+ meanEdgesPerGene = meanEdgesPerGene,
+ correlationWeight = correlationWeight,
+ outputDir = instabilitiesDir)
+ writeNetworkTable!(buildGrn; outputDir = instabilitiesDir)
+end
+
+# =============================================================================
+# STEP 5 — Aggregate TFA + mRNA networks into a consensus network
+# =============================================================================
+# Combines edge lists using max/mean/min stability per (TF, gene) pair.
+
+combinedNetDir = joinpath(dirOut, "Combined")
+nets2combine = [
+ joinpath(dirOut, "TFA", "edges.tsv"),
+ joinpath(dirOut, "TFmRNA", "edges.tsv")
+]
+InferelatorJL.aggregateNetworks(nets2combine;
+ method = Symbol(combineOpt),
+ meanEdgesPerGene = meanEdgesPerGene,
+ useMeanEdgesPerGene = useMeanEdgesPerGeneMode,
+ outputDir = combinedNetDir)
+
+# =============================================================================
+# STEP 6 — Re-estimate TFA using the consensus network as a refined prior
+# =============================================================================
+# Uses combined network as new prior, re-estimates TFA, re-runs mLASSO-StARS.
+
+netsCombinedSparse = joinpath(combinedNetDir, "combined_" * combineOpt * "_sp.tsv")
+InferelatorJL.refineTFA(data, mergedTFsData;
+ priorFile = netsCombinedSparse,
+ tfaGeneFile = tfaGeneFile,
+ edgeSS = edgeSS,
+ minTargets = minTargets,
+ zTarget = zScoreTFA,
+ geneExprFile = geneExprFile,
+ targFile = targFile,
+ regFile = regFile,
+ outputDir = combinedNetDir)
+
+@info "Pipeline complete" outputDir=dirOut
diff --git a/examples/plotPR.jl b/examples/plotPR.jl
new file mode 100644
index 0000000..5d8909a
--- /dev/null
+++ b/examples/plotPR.jl
@@ -0,0 +1,221 @@
+# =============================================================================
+# plotPR.jl — Evaluate GRN predictions against gold standards and plot PR curves
+#
+# What:
+# Evaluates one or more inferred GRNs against gold-standard interaction sets
+# by computing precision-recall (PR) and ROC metrics, then generating
+# publication-quality PR curve plots. Supports multiple networks and
+# multiple gold standards in a single run. Optionally generates per-TF
+# PR curves and AUPR bar plots.
+#
+# Required inputs:
+# outNetFiles — inferred GRN file(s) as legend label → file path dict
+# gsParam — gold-standard file(s) as name → file path dict
+# prTargGeneFile — target gene list used to restrict evaluation universe
+# (set to "" to use all genes in the network)
+# gsRegsFile — regulator list to restrict evaluation to shared TFs
+# (set to "" to use all regulators)
+#
+# Expected outputs (written relative to each network file's directory):
+#   PR_noPotRegs/<gsName>/   — PR data files (if gsRegsFile = "")
+#   PR_withPotRegs/<gsName>/ — PR data files (if gsRegsFile is set)
+#   dirOutPlot/<figBaseName>_<gsName>*.png — PR curve plots and optional AUPR bar plots
+#
+# Usage:
+# julia examples/plotPR.jl
+# or configure the USER CONFIG section and run step-by-step in the REPL
+#
+# Installation:
+# pkg> dev /path/to/InferelatorJL # local development
+# pkg> add "https://github.com/org/InferelatorJL.jl" # published release
+# Tip: load Revise before this file to pick up source edits without restarting:
+# using Revise; using InferelatorJL
+# =============================================================================
+
+using InferelatorJL
+import InferelatorJL: computePR, plotPRCurves, plotAUPR, loadPRData
+
+using OrderedCollections
+
+# ══════════════════════════════════════════════════════════════════════════════
+# USER CONFIG — edit this section
+# ══════════════════════════════════════════════════════════════════════════════
+
+# Output directory for plots
+dirOutPlot = "/data/miraldiNB/Michael/projects/goldStandard/Human/fdr0p01_vs_IDR"
+
+# Base name for saved figures (set to "" to use gold-standard name only)
+figBaseName = "ALL_5KB_sharedTF_Target"
+
+# Network files to compare: legend label => file path
+outNetFiles = OrderedDict(
+ "FDR0p01" => "/data/miraldiNB/Michael/projects/goldStandard/Human/GS_FDR0p01/20260325/GS_5KB_TSS/sumRegion_reducePool/All/context_mean/global_max/GS_All_peakScore10_top50_sharedTF_Gene_withIDR.tsv",
+ "IDR" => "/data/miraldiNB/Michael/projects/goldStandard/Human/GS_IDR/20260321/GS_5KB_TSS/sumRegion_reducePool/All/context_mean/global_max/GS_All_peakScore10_top50_sharedTF_Gene_withFDR0p01.tsv"
+)
+
+# Gold-standard files: name => file path
+gsParam = OrderedDict(
+ "KO_GS" => "/data/miraldiNB/Michael/projects/GRN/hCD4T_Katko/dataBank/GS/KO_GS_50_Michael_autosomal.tsv",
+)
+
+# Evaluation inputs
+prTargGeneFile = "/data/miraldiNB/Michael/projects/GRN/hCD4T_Katko/dataBank/potTargRegs/Targs/all_targs_autosomal.txt"
+gsRegsFile = "/data/miraldiNB/Katko/Projects/Barski_CD4_Multiome/Outs/Prior/SubsetPriors/all_TFs.txt"
+rankColTrn = 3 # column in GRN file corresponding to interaction ranks/confidences
+breakTies = true
+auprLimit = 0.1
+
+# Plot parameters
+lineTypes = [] # e.g. ["-", "--", "-."] — one per dataset; [] uses defaults
+lineWidths = [] # per dataset; [] uses defaults
+lineColors = [] # per dataset; [] uses defaults
+xLimitRecall = 0.1
+yStepSize = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1] # one per gold standard
+yScaleType = "linear"
+yZoomPR = [[0.4, 0.9], [0.2, 0.9], [0.3, 0.9], [], [], [], [], [], [], [], [0.7, 0.9], [], [0.7, 0.9]] # one per gold standard, or []
+heightRatios = [0.5, 3.0] # height ratios for broken y-axis panels
+isInside = false # legend inside the plot
+plotAUPRflag = false # set to true to also generate AUPR bar plots
+combinePlot = true # generate a combined PR curve per network/GS pair
+doPerTF = true # compute per-TF PR metrics
+tfList = [] # list of TF names for per-TF curves; [] skips Section 3
+
+# ══════════════════════════════════════════════════════════════════════════════
+# EXECUTION — no edits needed below this line
+# ══════════════════════════════════════════════════════════════════════════════
+
+mkpath(dirOutPlot)
+
+# --- Helper: resolve per-GS plot parameters ---
+function getPlotParams(i, gsName; figBaseName, yZoomPR, yStepSize)
+ saveNamePR = isempty(figBaseName) ? "$(gsName)" : "$(figBaseName)_$(gsName)"
+ currentYzoomPR = (length(yZoomPR) >= i && !isempty(yZoomPR[i])) ? yZoomPR[i] : Float64[]
+ currentYstepSize = (length(yStepSize) >= i && !isempty(yStepSize[i])) ? yStepSize[i] : nothing
+ return saveNamePR, currentYzoomPR, currentYstepSize
+end
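+# Example with hypothetical values: getPlotParams(2, "KO_GS"; figBaseName = "ALL",
+# yZoomPR = [[0.1, 0.9], []], yStepSize = [0.1, 0.1]) returns ("ALL_KO_GS", Float64[], 0.1);
+# the empty yZoomPR entry for the second gold standard falls back to no zoom.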
+
+# ── 1. Calculate PR/ROC metrics ───────────────────────────────────────────────
+@info "---- 1. Calculating Performance Metrics for the Networks -----"
+prFilesByGS = OrderedDict{String, OrderedDict{String, Any}}()
+
+for (legendLabel, outNetFile) in outNetFiles
+ @info "Processing network" network=legendLabel file=outNetFile
+ filepath = dirname(outNetFile)
+
+ for (gsName, gsFile) in gsParam
+ dirOut = isempty(gsRegsFile) ? joinpath(filepath, "PR_noPotRegs", gsName) :
+ joinpath(filepath, "PR_withPotRegs", gsName)
+ mkpath(dirOut)
+ @info "Using GS" gs=gsName saveDir=dirOut
+
+ res = computePR(gsFile, outNetFile;
+ gsRegsFile = gsRegsFile,
+ targGeneFile = prTargGeneFile,
+ rankColTrn = rankColTrn,
+ breakTies = breakTies,
+ partialAUPRlimit = auprLimit,
+ doPerTF = doPerTF,
+ saveDir = dirOut)
+
+ if !haskey(prFilesByGS, gsName)
+ prFilesByGS[gsName] = OrderedDict{String, Any}()
+ end
+ prFilesByGS[gsName][legendLabel] = haskey(res, :savedFile) ? res[:savedFile] : res
+ end
+end
+
+# ── 2. Global PR curves ───────────────────────────────────────────────────────
+@info "---- 2. Generating Global PR Curves ----"
+if combinePlot
+ for (i, (gsName, listFilePR)) in enumerate(prFilesByGS)
+ @info "Plotting PR curves" gs=gsName
+
+ saveNamePR, currentYzoomPR, currentYstepSize = getPlotParams(i, gsName;
+ figBaseName = figBaseName,
+ yZoomPR = yZoomPR,
+ yStepSize = yStepSize)
+
+ plotPRCurves(listFilePR, dirOutPlot, saveNamePR;
+ xLimitRecall = xLimitRecall,
+ yZoomPR = currentYzoomPR,
+ yStepSize = currentYstepSize,
+ yScale = yScaleType,
+ isInside = isInside,
+ lineColors = lineColors,
+ lineTypes = lineTypes,
+ lineWidths = lineWidths,
+ heightRatios = heightRatios,
+ mode = :global)
+
+ if plotAUPRflag
+ singleGS = OrderedDict(gsName => listFilePR)
+ saveNameAUPR = isempty(figBaseName) ? "$(gsName)" : "$(figBaseName)_$(gsName)"
+
+ for (figSize, saveLegend) in [((5, 4), true), ((1.5, 1.5), false)]
+ plotAUPR(singleGS, dirOutPlot;
+ saveName = saveNameAUPR,
+ metricType = "partial",
+ figSize = figSize,
+ axisTitleSize = 9,
+ tickLabelSize = 7,
+ legendFontSize = 9,
+ tickRotation = 45,
+ plotType = "bar",
+ saveLegend = saveLegend)
+ end
+ end
+ @info "Plots completed" gs=gsName
+ end
+end
+
+# ── 3. Per-TF PR curves ───────────────────────────────────────────────────────
+if !isempty(tfList)
+ @info "----- 3. Generating Per-TF PR Curves -----"
+ for (i, (gsName, resultsDict)) in enumerate(prFilesByGS)
+ @info "Plotting per-TF PR curves" gs=gsName
+
+ saveNamePR, currentYzoomPR, currentYstepSize = getPlotParams(i, gsName;
+ figBaseName = figBaseName,
+ yZoomPR = yZoomPR,
+ yStepSize = yStepSize)
+ saveNamePR = "perTF_$(saveNamePR)"
+ tfListPR = OrderedDict()
+
+ resCache = Dict{String, Any}()
+ for (runName, source) in resultsDict
+ resCache[runName] = loadPRData(source; mode = :perTF)
+ end
+
+ for (runName, res) in resCache
+ res === nothing && continue
+ tfIndex = Dict(tf => j for (j, tf) in enumerate(res[:gsRegs]))
+ for tf in tfList
+ idx = get(tfIndex, tf, nothing)
+ idx === nothing && continue
+
+ label =
+ length(tfList) == 1 && length(resultsDict) > 1 ? runName :
+ length(resultsDict) == 1 && length(tfList) > 1 ? tf :
+ "$runName - $tf"
+
+ tfListPR[label] = Dict(
+ :precisions => res[:precisions][idx],
+ :recalls => res[:recalls][idx],
+ :randPR => res[:randPR][idx]
+ )
+ end
+ end
+
+ plotPRCurves(tfListPR, dirOutPlot, saveNamePR;
+ xLimitRecall = xLimitRecall,
+ yZoomPR = currentYzoomPR,
+ yStepSize = currentYstepSize,
+ yScale = yScaleType,
+ isInside = isInside,
+ lineColors = lineColors,
+ lineTypes = lineTypes,
+ lineWidths = lineWidths)
+ end
+end
+
+@info "Completed — plots generated for all gold standards"
diff --git a/examples/run_pipeline.jl b/examples/run_pipeline.jl
new file mode 100644
index 0000000..0ce0342
--- /dev/null
+++ b/examples/run_pipeline.jl
@@ -0,0 +1,160 @@
+# =============================================================================
+# run_pipeline.jl — Batch GRN inference wrapped in a callable function (public API)
+#
+# What:
+# Wraps the full 6-step mLASSO-StARS pipeline in runInferelator(), which
+# accepts all parameters as keyword arguments with sensible defaults.
+# Use this for batch execution, scripted cluster jobs, or when running
+# multiple parameter sweeps programmatically.
+# All steps call high-level public API functions (no internal functions exposed).
+#
+# Required inputs:
+# geneExprFile — gene expression matrix (.txt tab-delimited or .arrow)
+# targFile — target gene list (.txt, one gene per line)
+# regFile — potential regulator (TF) list (.txt, one TF per line)
+# priorFile — prior network matrix (.tsv, sparse TF × gene)
+# priorFilePenalties — prior(s) used to set LASSO penalties (same or different)
+# outputDir — root directory for all outputs
+#
+# Expected outputs (written under outputDir/<networkBaseName>/):
+# TFA/edges.tsv — GRN inferred with TFA predictors
+# TFmRNA/edges.tsv — GRN inferred with TF mRNA predictors
+# Combined/combined_*.tsv — consensus network (max/mean/min aggregation)
+# Combined/TFA/ — refined TFA network re-estimated from consensus prior
+#
+# Usage:
+# julia examples/run_pipeline.jl
+# or call runInferelator() from another script after: using InferelatorJL
+#
+# Compare with:
+# interactive_pipeline.jl — same steps run interactively (not wrapped)
+# run_pipeline_dev.jl — same function using module-qualified internal calls
+#
+# Installation:
+# pkg> dev /path/to/InferelatorJL # local development
+# pkg> add "https://github.com/org/InferelatorJL.jl" # published release
+# Tip: load Revise before this file to pick up source edits without restarting:
+# using Revise; using InferelatorJL
+# =============================================================================
+
+using InferelatorJL
+
+# =============================================================================
+# Pipeline function
+# =============================================================================
+
+function runInferelator(;
+ geneExprFile::String,
+ targFile::String,
+ regFile::String,
+ priorFile::String,
+ priorFilePenalties::Vector{String},
+ tfaGeneFile::String = "",
+ outputDir::String,
+ totSS::Int = 80,
+ bstarsTotSS::Int = 5,
+ subsampleFrac::Float64 = 0.68,
+ minLambda::Float64 = 0.01,
+ maxLambda::Float64 = 0.5,
+ totLambdasBstars::Int = 20,
+ totLambdas::Int = 40,
+ targetInstability::Float64 = 0.05,
+ meanEdgesPerGene::Int = 20,
+ correlationWeight::Int = 1,
+ minTargets::Int = 3,
+ edgeSS::Int = 0,
+ lambdaBias::Vector{Float64} = [0.5],
+ instabilityLevel::String = "Network", # "Network" or "Gene"
+ useMeanEdgesPerGeneMode::Bool = true,
+ combineOpt::String = "max", # "max", "mean", or "min"
+ zScoreTFA::Bool = true,
+ zScoreLASSO::Bool = true
+)
+ # Build output directory name (encodes key run parameters)
+    subsamplePct = round(subsampleFrac * 100, digits = 4)  # round to avoid float artifacts (e.g. 0.29 * 100 == 28.999999999999996)
+    subsampleStr = isinteger(subsamplePct) ? string(Int(subsamplePct)) : replace(string(subsamplePct), "." => "p")
+ lambdaStr = join(replace.(string.(lambdaBias), "." => "p"), "_")
+ networkBaseName = lowercase(instabilityLevel) * "Lambda" * lambdaStr * "_" * string(totSS) * "totSS_" *
+ string(meanEdgesPerGene) * "tfsPerGene_" * "subsamplePCT" * subsampleStr
+ dirOut = joinpath(outputDir, networkBaseName)
+ mkpath(dirOut)
+
+ @info "Starting pipeline" outputDir=dirOut geneExprFile priorFile lambdaBias subsampleFrac
+
+ # Step 1 — Load and filter expression data
+ data = loadData(geneExprFile, targFile, regFile;
+ tfaGeneFile = tfaGeneFile,
+ epsilon = 0.01)
+
+ # Steps 2 + 3 — Merge degenerate TFs, process prior, estimate TFA
+ priorData, mergedTFs = loadPrior(data, priorFile; minTargets = minTargets)
+
+ estimateTFA(priorData, data;
+ edgeSS = edgeSS,
+ zScoreTFA = zScoreTFA,
+ outputDir = dirOut)
+
+ # Step 4 — Build GRN for each predictor mode
+ for (tfaMode, modeLabel) in [(true, "TFA"), (false, "TFmRNA")]
+ instabilitiesDir = joinpath(dirOut, modeLabel)
+ mkpath(instabilitiesDir)
+
+ @info "Building network" mode=modeLabel
+
+ buildNetwork(data, priorData;
+ tfaMode = tfaMode,
+ priorFilePenalties = priorFilePenalties,
+ lambdaBias = lambdaBias,
+ totSS = totSS,
+ bstarsTotSS = bstarsTotSS,
+ subsampleFrac = subsampleFrac,
+ minLambda = minLambda,
+ maxLambda = maxLambda,
+ totLambdasBstars = totLambdasBstars,
+ totLambdas = totLambdas,
+ targetInstability = targetInstability,
+ meanEdgesPerGene = meanEdgesPerGene,
+ correlationWeight = correlationWeight,
+ instabilityLevel = instabilityLevel,
+ useMeanEdgesPerGeneMode = useMeanEdgesPerGeneMode,
+ zScoreLASSO = zScoreLASSO,
+ outputDir = instabilitiesDir)
+ end
+
+ # Step 5 — Aggregate TFA + mRNA networks into a consensus network
+ combinedNetDir = joinpath(dirOut, "Combined")
+ aggregateNetworks(
+ [joinpath(dirOut, "TFA", "edges.tsv"),
+ joinpath(dirOut, "TFmRNA", "edges.tsv")];
+ method = Symbol(combineOpt),
+ meanEdgesPerGene = meanEdgesPerGene,
+ useMeanEdgesPerGene = useMeanEdgesPerGeneMode,
+ outputDir = combinedNetDir)
+
+ # Step 6 — Re-estimate TFA using the consensus network as a refined prior
+ netsCombinedSparse = joinpath(combinedNetDir, "combined_" * combineOpt * "_sp.tsv")
+ refineTFA(netsCombinedSparse, data, mergedTFs;
+ tfaGeneFile = tfaGeneFile,
+ edgeSS = edgeSS,
+ minTargets = minTargets,
+ zScoreTFA = zScoreTFA,
+ exprFile = geneExprFile,
+ targFile = targFile,
+ regFile = regFile,
+ outputDir = combinedNetDir)
+
+ @info "Pipeline complete" outputDir=dirOut
+end
+
+# =============================================================================
+# Run — replace paths with your own data
+# =============================================================================
+
+runInferelator(
+ geneExprFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/pseudobulk/pseudobulk_scrna/CellType/Age/Factor1/min0.25M/counts_Tfh10_AgeCellType_pseudobulk_scrna_vst_batch_NoState.txt",
+ targFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/target_genes/gene_targ_Tfh10_SigPct5Log2FC0p58FDR5.txt",
+ regFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/pot_regs/TF_Tfh10_SigPct5Log2FC0p58FDR5_final.txt",
+ priorFile = "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv",
+ priorFilePenalties = ["/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv"],
+ outputDir = "/data/miraldiNB/Michael/projects/GRN/mCD4T_Wayman/Inferelator/test"
+)
diff --git a/examples/run_pipeline_dev.jl b/examples/run_pipeline_dev.jl
new file mode 100644
index 0000000..af51089
--- /dev/null
+++ b/examples/run_pipeline_dev.jl
@@ -0,0 +1,184 @@
+# =============================================================================
+# run_pipeline_dev.jl — Batch GRN inference via module-qualified internal calls
+#
+# What:
+# Identical pipeline to run_pipeline.jl but calls internal functions directly
+# using the InferelatorJL. module prefix. Use this when you need finer control
+# over individual pipeline steps while still running in batch/function mode.
+# All internal functions remain accessible via module-qualified calls even
+# though they are not exported from the package.
+#
+# Required inputs:
+# geneExprFile — gene expression matrix (.txt tab-delimited or .arrow)
+# targFile — target gene list (.txt, one gene per line)
+# regFile — potential regulator (TF) list (.txt, one TF per line)
+# priorFile — prior network matrix (.tsv, sparse TF × gene)
+# priorFilePenalties — prior(s) used to set LASSO penalties (same or different)
+# outputDir — root directory for all outputs
+#
+# Expected outputs (written under outputDir/<networkBaseName>/):
+# TFA/edges.tsv — GRN inferred with TFA predictors
+# TFmRNA/edges.tsv — GRN inferred with TF mRNA predictors
+# Combined/combined_*.tsv — consensus network (max/mean/min aggregation)
+# Combined/TFA/ — refined TFA network re-estimated from consensus prior
+#
+# Usage:
+# julia examples/run_pipeline_dev.jl
+# or call runInferelator() from another script after: using InferelatorJL
+#
+# Compare with:
+# interactive_pipeline_dev.jl — same internal calls run interactively
+# run_pipeline.jl — same function using public API
+#
+# Installation:
+# pkg> dev /path/to/InferelatorJL # local development
+# pkg> add "https://github.com/org/InferelatorJL.jl" # published release
+# Tip: load Revise before this file to pick up source edits without restarting:
+# using Revise; using InferelatorJL
+# =============================================================================
+
+using InferelatorJL
+
+# =============================================================================
+# Pipeline function
+# =============================================================================
+
+function runInferelator(;
+ geneExprFile::String,
+ targFile::String,
+ regFile::String,
+ priorFile::String,
+ priorFilePenalties::Vector{String},
+ tfaGeneFile::String = "",
+ outputDir::String,
+ tfaOptions::Vector{String} = ["", "TFmRNA"], # "" → TFA, "TFmRNA" → mRNA
+ totSS::Int = 80,
+ bstarsTotSS::Int = 5,
+ subsampleFrac::Float64 = 0.68,
+ minLambda::Float64 = 0.01,
+ maxLambda::Float64 = 0.5,
+ totLambdasBstars::Int = 20,
+ totLambdas::Int = 40,
+ targetInstability::Float64 = 0.05,
+ meanEdgesPerGene::Int = 20,
+ correlationWeight::Int = 1,
+ minTargets::Int = 3,
+ edgeSS::Int = 0,
+ lambdaBias::Vector{Float64} = [0.5],
+ instabilityLevel::String = "Network", # "Network" or "Gene"
+ useMeanEdgesPerGeneMode::Bool = true,
+ combineOpt::String = "max", # "max", "mean", or "min"
+ zScoreTFA::Bool = true,
+ zScoreLASSO::Bool = true
+)
+ # Build output directory name (encodes key run parameters)
+ subsamplePct = subsampleFrac * 100
+ subsampleStr = isinteger(subsamplePct) ? string(Int(subsamplePct)) : replace(string(subsamplePct), "." => "p")
+ lambdaStr = join(replace.(string.(lambdaBias), "." => "p"), "_")
+ networkBaseName = lowercase(instabilityLevel) * "Lambda" * lambdaStr * "_" * string(totSS) * "totSS_" *
+ string(meanEdgesPerGene) * "tfsPerGene_" * "subsamplePCT" * subsampleStr
+ dirOut = joinpath(outputDir, networkBaseName)
+ mkpath(dirOut)
+
+ @info "Starting pipeline" outputDir=dirOut geneExprFile priorFile lambdaBias subsampleFrac
+
+ # Step 1 — Load and filter expression data
+ data = GeneExpressionData()
+ InferelatorJL.loadExpressionData!(data, geneExprFile)
+ InferelatorJL.loadAndFilterTargetGenes!(data, targFile; epsilon = 0.01)
+ InferelatorJL.loadPotentialRegulators!(data, regFile)
+ InferelatorJL.processTFAGenes!(data, tfaGeneFile; outputDir = dirOut)
+
+ # Step 2 — Merge degenerate TFs
+ mergedTFsData = mergedTFsResult()
+ InferelatorJL.mergeDegenerateTFs(mergedTFsData, priorFile; fileFormat = 2)
+
+ # Step 3 — Process prior and estimate TFA
+ tfaData = PriorTFAData()
+ InferelatorJL.processPriorFile!(tfaData, data, priorFile;
+ mergedTFsData = mergedTFsData,
+ minTargets = minTargets)
+ InferelatorJL.calculateTFA!(tfaData, data;
+ edgeSS = edgeSS,
+ zTarget = zScoreTFA,
+ outputDir = dirOut)
+
+ # Step 4 — Build GRN for each predictor mode
+ for tfaOpt in tfaOptions
+ instabilitiesDir = tfaOpt == "" ? joinpath(dirOut, "TFA") : joinpath(dirOut, "TFmRNA")
+ mkpath(instabilitiesDir)
+
+ @info "Building network" tfaOpt=(isempty(tfaOpt) ? "TFA" : tfaOpt)
+
+ grnData = GrnData()
+ InferelatorJL.preparePredictorMat!(grnData, data, tfaData; tfaOpt = tfaOpt)
+ InferelatorJL.preparePenaltyMatrix!(data, grnData;
+ priorFilePenalties = priorFilePenalties,
+ lambdaBias = lambdaBias,
+ tfaOpt = tfaOpt)
+ InferelatorJL.constructSubsamples(data, grnData; totSS = bstarsTotSS, subsampleFrac = subsampleFrac)
+ InferelatorJL.bstarsWarmStart(data, tfaData, grnData;
+ minLambda = minLambda,
+ maxLambda = maxLambda,
+ totLambdasBstars = totLambdasBstars,
+ targetInstability = targetInstability,
+ zTarget = zScoreLASSO)
+ InferelatorJL.constructSubsamples(data, grnData; totSS = totSS, subsampleFrac = subsampleFrac)
+ InferelatorJL.bstartsEstimateInstability(grnData;
+ totLambdas = totLambdas,
+ instabilityLevel = instabilityLevel,
+ zTarget = zScoreLASSO,
+ outputDir = instabilitiesDir)
+
+ buildGrn = BuildGrn()
+ InferelatorJL.chooseLambda!(grnData, buildGrn;
+ instabilityLevel = instabilityLevel,
+ targetInstability = targetInstability)
+ InferelatorJL.rankEdges!(data, tfaData, grnData, buildGrn;
+ useMeanEdgesPerGeneMode = useMeanEdgesPerGeneMode,
+ meanEdgesPerGene = meanEdgesPerGene,
+ correlationWeight = correlationWeight,
+ outputDir = instabilitiesDir)
+ writeNetworkTable!(buildGrn; outputDir = instabilitiesDir)
+ end
+
+ # Step 5 — Aggregate TFA + mRNA networks into a consensus network
+ combinedNetDir = joinpath(dirOut, "Combined")
+ nets2combine = [
+ joinpath(dirOut, "TFA", "edges.tsv"),
+ joinpath(dirOut, "TFmRNA", "edges.tsv")
+ ]
+ InferelatorJL.aggregateNetworks(nets2combine;
+ method = Symbol(combineOpt),
+ meanEdgesPerGene = meanEdgesPerGene,
+ useMeanEdgesPerGene = useMeanEdgesPerGeneMode,
+ outputDir = combinedNetDir)
+
+ # Step 6 — Re-estimate TFA using the consensus network as a refined prior
+ netsCombinedSparse = joinpath(combinedNetDir, "combined_" * combineOpt * "_sp.tsv")
+ InferelatorJL.refineTFA(data, mergedTFsData;
+ priorFile = netsCombinedSparse,
+ tfaGeneFile = tfaGeneFile,
+ edgeSS = edgeSS,
+ minTargets = minTargets,
+ zTarget = zScoreTFA,
+ geneExprFile = geneExprFile,
+ targFile = targFile,
+ regFile = regFile,
+ outputDir = combinedNetDir)
+
+ @info "Pipeline complete" outputDir=dirOut
+end
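As a standalone sanity check of the directory-name encoding inside `runInferelator`, here is the same logic copied out with the default parameter values (variable names are local to this sketch, not part of the package):

```julia
# Rebuilds the output-directory basename using the function's defaults:
# instabilityLevel = "Network", lambdaBias = [0.5], totSS = 80,
# meanEdgesPerGene = 20, subsampleFrac = 0.68.
subsampleFrac = 0.68
lambdaBias = [0.5]
subsamplePct = subsampleFrac * 100
subsampleStr = isinteger(subsamplePct) ? string(Int(subsamplePct)) :
               replace(string(subsamplePct), "." => "p")
lambdaStr = join(replace.(string.(lambdaBias), "." => "p"), "_")
name = lowercase("Network") * "Lambda" * lambdaStr * "_" * string(80) * "totSS_" *
       string(20) * "tfsPerGene_" * "subsamplePCT" * subsampleStr
println(name)  # networkLambda0p5_80totSS_20tfsPerGene_subsamplePCT68
```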
+
+# =============================================================================
+# Run — replace paths with your own data
+# =============================================================================
+
+runInferelator(
+ geneExprFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/pseudobulk/pseudobulk_scrna/CellType/Age/Factor1/min0.25M/counts_Tfh10_AgeCellType_pseudobulk_scrna_vst_batch_NoState.txt",
+ targFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/target_genes/gene_targ_Tfh10_SigPct5Log2FC0p58FDR5.txt",
+ regFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/pot_regs/TF_Tfh10_SigPct5Log2FC0p58FDR5_final.txt",
+ priorFile = "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv",
+ priorFilePenalties = ["/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv"],
+ outputDir = "/data/miraldiNB/Michael/projects/GRN/mCD4T_Wayman/Inferelator/test"
+)
diff --git a/examples/utilityExamples.jl b/examples/utilityExamples.jl
new file mode 100644
index 0000000..c88413d
--- /dev/null
+++ b/examples/utilityExamples.jl
@@ -0,0 +1,245 @@
+# =============================================================================
+# utilityExamples.jl — Self-contained examples for InferelatorJL utility functions
+#
+# What:
+# Demonstrates all exported utility functions using small synthetic datasets.
+# No real input files are required — all data is generated inline.
+# Sections:
+# 1. DataUtils — reshape, normalize, binarize, merge prior matrices
+# 2. PartialCorrelation — precision-matrix and regression-based methods
+# 3. NetworkIO — write edge tables and save data structs
+# 4. aggregateNetworks — combine multiple GRN edge files into one network
+#
+# Required inputs: None (all data is synthetic)
+#
+# Expected outputs: Printed results to the REPL / stdout.
+# Sections 3 and 4 also write temporary files to tempdir().
+#
+# Usage:
+# julia examples/utilityExamples.jl
+# or paste individual sections into a Julia REPL
+#
+# Installation:
+# pkg> dev /path/to/InferelatorJL # local development
+# pkg> add "https://github.com/org/InferelatorJL.jl" # published release
+# Tip: load Revise before this file to pick up source edits without restarting:
+# using Revise; using InferelatorJL
+# =============================================================================
+
+using InferelatorJL
+using DataFrames, CSV, Random, Statistics, LinearAlgebra
+
+
+# =============================================================================
+# 1. DataUtils
+# =============================================================================
+
+println("\n===== 1. DataUtils =====\n")
+
+# ── convertToLong / convertToWide ────────────────────────────────────────────
+# Wide prior matrix: genes (rows) × TFs (columns)
+widePrior = DataFrame(
+ Gene = ["Gata3", "Foxp3", "Tbx21", "Rorc"],
+ TF_A = [1.0, 0.0, 0.5, 0.0],
+ TF_B = [0.0, 1.0, 0.0, 1.0],
+ TF_C = [0.8, 0.0, 0.0, 0.6]
+)
+
+# Convert to long (TF, Gene, Weight)
+longPrior = convertToLong(widePrior)
+println("Wide → Long (first 5 rows):")
+println(first(longPrior, 5))
+
+# Round-trip back to wide
+widePrior2 = convertToWide(longPrior; indices = (2, 1, 3))
+println("\nLong → Wide (restored):")
+println(widePrior2)
+
+
+# ── frobeniusNormalize ────────────────────────────────────────────────────────
+# Normalize each column of the prior matrix so each column has unit L2 norm.
+# Typical use: normalize prior before feeding into penalty matrix construction.
+priorNorm = frobeniusNormalize(widePrior, :column)
+println("\nColumn-normalized prior (each TF column has ||·||₂ = 1):")
+println(priorNorm)
+
+# Verify norms are 1
+details, nTrue, nFalse = check_column_norms(priorNorm; atol = 1e-8)
+println("Columns with unit norm: $nTrue / $(nTrue + nFalse)")
+
+# Row normalization (normalize each gene's regulatory weight vector)
+priorNormRow = frobeniusNormalize(widePrior, :row)
+println("\nRow-normalized prior:")
+println(priorNormRow)
+
+
+# ── binarizeNumeric! ──────────────────────────────────────────────────────────
+# Replace all non-zero values with 1 (convert continuous prior → binary prior).
+priorBinary = deepcopy(widePrior)
+binarizeNumeric!(priorBinary)
+println("\nBinarized prior:")
+println(priorBinary)
+
+
+# ── mergeDFs ──────────────────────────────────────────────────────────────────
+# Merge two prior DataFrames from different assays (ATAC + ChIP) by summing.
+# Useful when combining evidence from multiple prior sources.
+prior_atac = DataFrame(Gene = ["Gata3","Foxp3","Tbx21"], TF_A = [1.0,0.0,0.5], TF_B = [0.0,1.0,0.0])
+prior_chip = DataFrame(Gene = ["Foxp3","Tbx21","Rorc"], TF_A = [0.0,0.3,0.0], TF_C = [1.0,0.0,1.0])
+
+mergedPrior = mergeDFs([prior_atac, prior_chip], :Gene, "sum")
+println("\nMerged prior (ATAC + ChIP, sum):")
+println(mergedPrior)
+
+mergedPriorAvg = mergeDFs([prior_atac, prior_chip], :Gene, "avg")
+println("\nMerged prior (ATAC + ChIP, average):")
+println(mergedPriorAvg)
+
+
+# ── completeDF ────────────────────────────────────────────────────────────────
+# Align a DataFrame to a fixed set of row IDs and column names (fills missing → 0).
+allGenes = ["Gata3", "Foxp3", "Tbx21", "Rorc", "Il2ra"]
+allTFs = [:TF_A, :TF_B, :TF_C, :TF_D]
+
+completedPrior = completeDF(prior_atac, :Gene, allGenes, allTFs)
+println("\ncompleteDF — prior_atac aligned to full gene/TF universe:")
+println(completedPrior)
+
+
+# ── writeTSVWithEmptyFirstHeader ──────────────────────────────────────────────
+# Write a DataFrame to TSV with the first (row-label) column header left blank.
+# This is the sparse prior format expected by processPriorFile!.
+tmpDir = tempdir()
+sparseFile = joinpath(tmpDir, "prior_example_sp.tsv")
+writeTSVWithEmptyFirstHeader(priorBinary, sparseFile; delim = '\t')
+println("\nSparse prior written to: $sparseFile")
+println(first(read(sparseFile, String), 200))
+
+
+# =============================================================================
+# 2. PartialCorrelation
+# =============================================================================
+
+println("\n===== 2. PartialCorrelation =====\n")
+
+# Synthetic expression matrix: 50 cells × 6 genes
+Random.seed!(42)
+X = randn(50, 6)
+# Add a latent signal so genes 1 and 3 are partially correlated
+latent = randn(50)
+X[:, 1] .+= 0.8 .* latent
+X[:, 3] .+= 0.7 .* latent
+
+# ── Method 1: Precision matrix ────────────────────────────────────────────────
+# partialCorrelationMat returns the full p×p partial correlation matrix.
+# Note: accessed via module prefix since it is an internal function.
+Pfull = InferelatorJL.partialCorrelationMat(X; epsilon = 1e-6, first_vs_all = false)
+println("Full partial correlation matrix (6×6):")
+display(round.(Pfull, digits = 3))
+
+# first_vs_all = true: only partial correlations of column 1 vs all others.
+# This is how it is used inside rankEdges! (target gene vs. TF predictors).
+P1 = InferelatorJL.partialCorrelationMat(X; epsilon = 1e-6, first_vs_all = true)
+println("\nPartial correlations of gene 1 vs. all others (1×6):")
+println(round.(P1, digits = 3))
+
+# ── Method 2: Regression residuals ───────────────────────────────────────────
+# partialCorrReg computes the same quantity via OLS residuals.
+# Slower than the precision-matrix method; use for small matrices.
+Preg = InferelatorJL.partialCorrReg(X; first_vs_all = false)
+println("\nPartial correlation (regression method, 6×6):")
+display(round.(Preg, digits = 3))
+
+println("\nDifference between methods (should be ~0):")
+println("Max abs diff: ", round(maximum(abs.(Pfull[2:end, 2:end] .- Preg[2:end, 2:end])), digits = 6))
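The precision-matrix route works because partial correlations are (up to sign) normalized entries of the inverse covariance: ρ_{ij·rest} = −Θ_{ij}/√(Θ_{ii}Θ_{jj}). A minimal check of that identity on a hand-written covariance (values invented; `pcor` is local to this sketch, not package API):

```julia
using LinearAlgebra

S = [1.0 0.5 0.2;
     0.5 1.0 0.1;
     0.2 0.1 1.0]          # a small positive-definite covariance
Θ = inv(S)                 # precision matrix
pcor(i, j) = -Θ[i, j] / sqrt(Θ[i, i] * Θ[j, j])
println(round(pcor(1, 2), digits = 3))   # ≈ 0.492
```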
+
+
+# =============================================================================
+# 3. NetworkIO
+# =============================================================================
+
+println("\n===== 3. NetworkIO =====\n")
+
+# ── writeNetworkTable! ────────────────────────────────────────────────────────
+# Populate a minimal BuildGrn with synthetic edge data and write TSV outputs.
+buildGrn = BuildGrn()
+buildGrn.regs = ["TF_A", "TF_B", "TF_A", "TF_C"]
+buildGrn.targs = ["Gata3", "Foxp3", "Tbx21", "Rorc"]
+buildGrn.signedQuantile = [0.92, -0.75, 0.60, 0.85]
+buildGrn.rankings = [0.88, 0.72, 0.55, 0.80]
+buildGrn.partialCorrelation = [0.45, -0.38, 0.30, 0.41]
+buildGrn.inPrior = ["Yes", "Yes", "No", "Yes"]
+buildGrn.networkMat = hcat(buildGrn.regs, buildGrn.targs,
+ buildGrn.signedQuantile, buildGrn.rankings,
+ buildGrn.partialCorrelation, buildGrn.inPrior)
+buildGrn.networkMatSubset = buildGrn.networkMat[1:3, :] # top 3 edges as subset
+
+netDir = joinpath(tmpDir, "network_example")
+mkpath(netDir)
+writeNetworkTable!(buildGrn; outputDir = netDir)
+println("edges.tsv written to: $netDir")
+println(read(joinpath(netDir, "edges.tsv"), String))
+
+
+# ── saveData ──────────────────────────────────────────────────────────────────
+# Save all four core structs to a .jld2 file for checkpointing.
+# Reload in a fresh session with:
+# using InferelatorJL, JLD2
+# @load joinpath(netDir, "checkpoint.jld2") expressionData tfaData grnData buildGrn
+#
+# (Skipped here since data/tfaData/grnData require real inputs)
+
+
+# =============================================================================
+# 4. aggregateNetworks
+# =============================================================================
+
+println("\n===== 4. aggregateNetworks =====\n")
+
+# Write two synthetic edge files (TFA and mRNA modes) then combine them.
+tfaEdges = DataFrame(
+ TF = ["TF_A", "TF_B", "TF_C", "TF_A"],
+ Gene = ["Gata3", "Foxp3", "Tbx21", "Rorc"],
+ signedQuantile = [0.92, -0.75, 0.60, 0.85],
+ Stability = [0.88, 0.72, 0.55, 0.80],
+ Correlation = [0.45, -0.38, 0.30, 0.41],
+ inPrior = ["Yes", "Yes", "No", "Yes"]
+)
+
+mrnaEdges = DataFrame(
+ TF = ["TF_A", "TF_B", "TF_C", "TF_D"],
+ Gene = ["Gata3", "Foxp3", "Tbx21", "Rorc"],
+ signedQuantile = [0.80, -0.65, 0.70, 0.55],
+ Stability = [0.75, 0.60, 0.68, 0.50],
+ Correlation = [0.40, -0.32, 0.35, 0.28],
+ inPrior = ["Yes", "Yes", "No", "No"]
+)
+
+tfaFile = joinpath(tmpDir, "TFA_edges.tsv")
+mrnaFile = joinpath(tmpDir, "mRNA_edges.tsv")
+CSV.write(tfaFile, tfaEdges; delim = '\t')
+CSV.write(mrnaFile, mrnaEdges; delim = '\t')
+
+combDir = joinpath(tmpDir, "Combined")
+
+# Combine using max stability per (TF, Gene) pair
+combinedMax = aggregateNetworks(
+ [tfaFile, mrnaFile];
+ method = :max,
+ meanEdgesPerGene = 3,
+ useMeanEdgesPerGene = true,
+ outputDir = combDir
+)
+println("Combined network (:max strategy), $(nrow(combinedMax)) edges:")
+println(combinedMax)
+
+# Combine using mean stability
+combinedMean = aggregateNetworks(
+ [tfaFile, mrnaFile];
+ method = :mean,
+ meanEdgesPerGene = 3,
+ useMeanEdgesPerGene = true,
+ outputDir = combDir
+)
+println("\nCombined network (:mean strategy), $(nrow(combinedMean)) edges:")
+println(combinedMean)
diff --git a/experimental/.DS_Store b/experimental/.DS_Store
new file mode 100644
index 0000000..99c7495
Binary files /dev/null and b/experimental/.DS_Store differ
diff --git a/experimental/MTL/MultitaskGRN_equations.pptx b/experimental/MTL/MultitaskGRN_equations.pptx
new file mode 100755
index 0000000..8f775a5
Binary files /dev/null and b/experimental/MTL/MultitaskGRN_equations.pptx differ
diff --git a/experimental/MTL/MultitaskInferelator.jl b/experimental/MTL/MultitaskInferelator.jl
new file mode 100755
index 0000000..aa335e4
--- /dev/null
+++ b/experimental/MTL/MultitaskInferelator.jl
@@ -0,0 +1,28 @@
+module MultitaskInferelator
+
+ # ── Unchanged modules from original Inferelator ──────────────────────────
+ include("core/geneExpression.jl") # Data module [UNCHANGED]
+ include("core/mergeDegenerateTFs.jl") # MergeDegenerate [UNCHANGED]
+ include("core/networkIO.jl") # NetworkIO [UNCHANGED]
+ include("utils/dataUtils.jl") # DataUtils [UNCHANGED]
+ include("utils/utilsGRN.jl") # firstNByGroup [UNCHANGED]
+
+ # ── Modified modules ──────────────────────────────────────────────────────
+ include("core/priorTFA.jl") # PriorTFA [MODIFIED - per-task TFA]
+ include("core/GRN.jl") # GRN structs [MODIFIED - new structs]
+
+ # ── New multitask modules ─────────────────────────────────────────────────
+ include("multitask/multitaskData.jl") # MT data structs [NEW]
+ include("multitask/taskSimilarity.jl") # Study graph [NEW]
+ include("multitask/admm.jl") # ADMM solver [NEW]
+ include("multitask/prepareMultitask.jl")# MT prepare fns [NEW]
+ include("multitask/buildMultitask.jl") # MT build/rank fns [NEW]
+ include("multitask/combineMultitask.jl")# MT combine fns [NEW]
+
+ # ── Evaluation (unchanged) ────────────────────────────────────────────────
+ include("evaluation/Metric.jl")
+ include("evaluation/calcPRinfTRNs.jl")
+ include("evaluation/plotSingleUtils.jl")
+ include("evaluation/plotBatchMetrics.jl")
+
+end
diff --git a/experimental/MTL/admm 2.jl b/experimental/MTL/admm 2.jl
new file mode 100755
index 0000000..a5cb45f
--- /dev/null
+++ b/experimental/MTL/admm 2.jl
@@ -0,0 +1,380 @@
+"""
+admm.jl [UPDATED]
+
+Changes vs original:
+─────────────────────────────────────────────────────────────────────
+NEW elasticNetAlpha parameter throughout (default 1.0 = pure LASSO)
+NEW selectLambdas() — top-level dispatcher for all three strategies
+NEW _selectLambdas_fixedRatio — Option 1: lambda_f = fusionRatio * lambda_s
+NEW _selectLambdas_ebic — Option 2: EBIC grid search
+NEW _selectLambdas_bstars2d — Option 3: 2D bStARS instability surface
+"""
+
+@inline function softThreshold(x::Float64, threshold::Float64)
+ return sign(x) * max(abs(x) - threshold, 0.0)
+end
+
+
+# ────────────────────────────────────────────────────────────────────────────
+# Core ADMM solver [UPDATED: elasticNetAlpha added]
+# ────────────────────────────────────────────────────────────────────────────
+
+"""
+ admm_fused_lasso(predictorMats, responseMats, penaltyFactors,
+ taskSimilarity, lambda_s, lambda_f;
+ rho, maxIter, tol, elasticNetAlpha)
+
+Solve the graph-fused multitask LASSO for a single target gene.
+
+elasticNetAlpha: L1/L2 mixing parameter passed to GLMNet alpha argument.
+ 1.0 (default) = pure LASSO
+ 0.5 = equal L1+L2 — recommended when demerged TFs are present
+ 0.0 = pure ridge
+"""
+function admm_fused_lasso(
+ predictorMats::Vector{Matrix{Float64}},
+ responseMats::Vector{Vector{Float64}},
+ penaltyFactors::Vector{Vector{Float64}},
+ taskSimilarity::Matrix{Float64},
+ lambda_s::Float64,
+ lambda_f::Float64;
+ rho::Float64 = 1.0,
+ maxIter::Int = 100,
+ tol::Float64 = 1e-4,
+ elasticNetAlpha::Float64 = 1.0
+ )
+
+ nTasks = length(predictorMats)
+ nTFs = size(predictorMats[1], 1)
+ edges = [(d, dp) for d in 1:nTasks for dp in (d+1):nTasks
+ if taskSimilarity[d, dp] > 0]
+ nEdges = length(edges)
+
+ W = zeros(Float64, nTFs, nTasks)
+ Z = zeros(Float64, nTFs, nEdges)
+ U = zeros(Float64, nTFs, nEdges)
+
+ for iter in 1:maxIter
+ W_prev = copy(W)
+
+ # W-update: per-task elastic net with ADMM augmentation
+ for d in 1:nTasks
+ dEdges = [(e, d1 == d ? 1.0 : -1.0)
+ for (e, (d1, d2)) in enumerate(edges)
+ if d1 == d || d2 == d]
+ nSamps = size(predictorMats[d], 2)
+ A = transpose(predictorMats[d])
+ x = responseMats[d]
+
+ if isempty(dEdges)
+ lsoln = glmnet(A, x,
+ penalty_factor = penaltyFactors[d],
+ lambda = [lambda_s],
+ alpha = elasticNetAlpha)
+ W[:, d] = vec(lsoln.betas)
+ else
+ augA = copy(A)
+ augX = copy(x)
+ for (e, sgn) in dEdges
+ target = sgn > 0 ? Z[:, e] - U[:, e] : -Z[:, e] - U[:, e]
+ sqrtRho = sqrt(rho)
+ augA = vcat(augA, sqrtRho * I(nTFs))
+ augX = vcat(augX, sqrtRho * target)
+ end
+ lsoln = glmnet(augA, augX,
+ penalty_factor = penaltyFactors[d],
+ lambda = [lambda_s / (nSamps + nTFs * length(dEdges))],
+ alpha = elasticNetAlpha)
+ W[:, d] = vec(lsoln.betas)
+ end
+ end
+
+ # Z-update: fusion proximal operator
+ for (e, (d, dp)) in enumerate(edges)
+ diff = W[:, d] - W[:, dp] + U[:, e]
+ threshold = lambda_f * taskSimilarity[d, dp] / rho
+ Z[:, e] = softThreshold.(diff, threshold)
+ end
+
+ # U-update: dual variable
+ for (e, (d, dp)) in enumerate(edges)
+ U[:, e] += W[:, d] - W[:, dp] - Z[:, e]
+ end
+
+ norm(W - W_prev) < tol && break
+ end
+
+ return W
+end
+
+
+function admmWarmStart(
+ predictorMats::Vector{Matrix{Float64}},
+ responseMats::Vector{Vector{Float64}},
+ penaltyFactors::Vector{Vector{Float64}},
+ taskSimilarity::Matrix{Float64},
+ lambdaRange_s::Vector{Float64},
+ lambda_f::Float64;
+ elasticNetAlpha::Float64 = 1.0,
+ kwargs...
+ )
+ nTFs = size(predictorMats[1], 1)
+ nTasks = length(predictorMats)
+ nLambda = length(lambdaRange_s)
+ betasByLambda = Array{Float64, 3}(undef, nTFs, nTasks, nLambda)
+
+ for (li, ls) in enumerate(lambdaRange_s)
+ betasByLambda[:, :, li] = admm_fused_lasso(
+ predictorMats, responseMats, penaltyFactors,
+ taskSimilarity, ls, lambda_f;
+ elasticNetAlpha = elasticNetAlpha, kwargs...
+ )
+ end
+ return betasByLambda
+end
+
+
+# ────────────────────────────────────────────────────────────────────────────
+# Lambda selection dispatcher [NEW]
+# ────────────────────────────────────────────────────────────────────────────
+
+"""
+ selectLambdas(mtGrnData, mtData, res; lambdaOpt, ...)
+
+Top-level dispatcher for joint (lambda_s, lambda_f) selection.
+
+lambdaOpt:
+ :fixed_ratio — lambda_f = fusionRatio * lambda_s (default, zero added cost)
+ :ebic — EBIC grid search over 2D (lambda_s, lambda_f)
+ :bstars_2d — 2D bStARS instability surface + EBIC tiebreaker
+
+Returns (lambda_s, lambda_f) for target gene `res`.
+"""
+function selectLambdas(
+ mtGrnData::MultitaskGrnData,
+ mtData::MultitaskExpressionData,
+ res::Int;
+ lambdaOpt::Symbol = :fixed_ratio,
+ fusionRatio::Float64 = 0.1,
+ ebicGamma::Float64 = 1.0,
+ gridSize::Int = 10,
+ refinementSize::Int = 10,
+ targetInstability::Float64 = 0.05,
+ elasticNetAlpha::Float64 = 1.0,
+ totSS::Int = 20,
+ zTarget::Bool = false
+ )
+
+ if lambdaOpt == :fixed_ratio
+ return _selectLambdas_fixedRatio(mtGrnData, res; fusionRatio)
+ elseif lambdaOpt == :ebic
+ return _selectLambdas_ebic(mtGrnData, mtData, res;
+ ebicGamma, gridSize, elasticNetAlpha, totSS, zTarget)
+ elseif lambdaOpt == :bstars_2d
+ return _selectLambdas_bstars2d(mtGrnData, mtData, res;
+ targetInstability, ebicGamma,
+ coarseSize = gridSize,
+ fineSize = refinementSize,
+ elasticNetAlpha, totSS, zTarget)
+ else
+ error("Unknown lambdaOpt: $lambdaOpt. Choose :fixed_ratio, :ebic, or :bstars_2d")
+ end
+end
+
+
+# ── Option 1: Fixed ratio ────────────────────────────────────────────────────
+function _selectLambdas_fixedRatio(mtGrnData::MultitaskGrnData, res::Int;
+ fusionRatio::Float64 = 0.1)
+ refGrn = mtGrnData.taskGrnData[1]
+ lambdaRangeGene = refGrn.lambdaRangeGene[res]
+ currInstabs = refGrn.geneInstabilities[res, :]
+ devs = abs.(currInstabs .- 0.05) # 0.05 = default target instability
+ globalMin = minimum(devs)
+ minInd = findlast(==(globalMin), devs) # largest lambda attaining the minimum
+ lambda_s = lambdaRangeGene[minInd]
+ lambda_f = fusionRatio * lambda_s
+ return lambda_s, lambda_f
+end
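A toy run of the fixed-ratio rule above, with synthetic instabilities (all numbers invented for illustration):

```julia
# Synthetic per-gene instabilities over a lambda range. The rule picks the
# lambda whose instability is closest to the 0.05 target (largest index on
# ties), then sets lambda_f as a fixed fraction of it.
lambdaRange = [0.01, 0.05, 0.1, 0.2, 0.5]
instabs     = [0.20, 0.12, 0.06, 0.02, 0.01]
devs        = abs.(instabs .- 0.05)
minInd      = findlast(==(minimum(devs)), devs)   # 3
lambda_s    = lambdaRange[minInd]                 # 0.1
lambda_f    = 0.1 * lambda_s                      # fusionRatio = 0.1
```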
+
+
+# ── Option 2: EBIC grid search ───────────────────────────────────────────────
+function _selectLambdas_ebic(
+ mtGrnData::MultitaskGrnData,
+ mtData::MultitaskExpressionData,
+ res::Int;
+ ebicGamma::Float64 = 1.0,
+ gridSize::Int = 10,
+ elasticNetAlpha::Float64 = 1.0,
+ totSS::Int = 20,
+ zTarget::Bool = false
+ )
+ nTasks = length(mtData.tasks)
+ responsePredInds = [findall(x -> x != Inf, mtGrnData.taskGrnData[d].penaltyMat[res, :])
+ for d in 1:nTasks]
+ taskPredMats = [transpose(mtGrnData.taskGrnData[d].predictorMat[responsePredInds[d], :])
+ for d in 1:nTasks]
+ taskRespVecs = [vec(mtGrnData.taskGrnData[d].responseMat[res, :]) for d in 1:nTasks]
+ taskPenalties = [mtGrnData.taskGrnData[d].penaltyMat[res, responsePredInds[d]]
+ for d in 1:nTasks]
+
+ lambda_s_grid = exp.(range(log(0.01), log(10.0), length = gridSize))
+ alpha_grid = range(0.05, 1.0, length = gridSize) # lambda_f / lambda_s ratio (not elastic-net alpha)
+
+ # Traverse from high to low (lambda_s + lambda_f) for warm starting
+ grid_pairs = [(ls, a * ls)
+ for ls in reverse(lambda_s_grid)
+ for a in reverse(alpha_grid)
+ if 0.5 < ls / (a * ls) < 2.0]
+
+ bestEBIC = Inf
+ best_ls = grid_pairs[1][1]
+ best_lf = grid_pairs[1][2]
+
+ for (ls, lf) in grid_pairs
+ W = admm_fused_lasso(taskPredMats, taskRespVecs, taskPenalties,
+ mtGrnData.taskGraph, ls, lf;
+ elasticNetAlpha = elasticNetAlpha)
+ ebic = _ebic(W, taskPredMats, taskRespVecs, nTasks, ebicGamma)
+ if ebic < bestEBIC
+ bestEBIC = ebic; best_ls = ls; best_lf = lf
+ end
+ end
+ return best_ls, best_lf
+end
+
+
+# ── Option 3: 2D bStARS ─────────────────────────────────────────────────────
+function _selectLambdas_bstars2d(
+ mtGrnData::MultitaskGrnData,
+ mtData::MultitaskExpressionData,
+ res::Int;
+ targetInstability::Float64 = 0.05,
+ ebicGamma::Float64 = 1.0,
+ coarseSize::Int = 5,
+ fineSize::Int = 10,
+ elasticNetAlpha::Float64 = 1.0,
+ totSS::Int = 20,
+ zTarget::Bool = false
+ )
+ nTasks = length(mtData.tasks)
+ responsePredInds = [findall(x -> x != Inf, mtGrnData.taskGrnData[d].penaltyMat[res, :])
+ for d in 1:nTasks]
+
+ # Stage 1: coarse grid
+ ls_c = exp.(range(log(0.01), log(5.0), length = coarseSize))
+ lf_c = exp.(range(log(0.001), log(2.0), length = coarseSize))
+ D_c = _instabGrid(mtGrnData, mtData, res, ls_c, lf_c,
+ responsePredInds, nTasks, totSS, zTarget, elasticNetAlpha)
+
+ # Identify region near contour
+ near = findall(x -> abs(x - targetInstability) < targetInstability * 0.5, D_c)
+ if isempty(near)
+ _, idx = findmin(abs.(D_c .- targetInstability))
+ near = [idx]
+ end
+ rs = first.(Tuple.(near)); cs = last.(Tuple.(near))
+ ls_range = ls_c[max(1, minimum(rs)-1):min(coarseSize, maximum(rs)+1)]
+ lf_range = lf_c[max(1, minimum(cs)-1):min(coarseSize, maximum(cs)+1)]
+
+ # Stage 2: fine grid in region
+ ls_f = collect(range(minimum(ls_range), maximum(ls_range), length = fineSize))
+ lf_f = collect(range(minimum(lf_range), maximum(lf_range), length = fineSize))
+ D_f = _instabGrid(mtGrnData, mtData, res, ls_f, lf_f,
+ responsePredInds, nTasks, totSS, zTarget, elasticNetAlpha)
+
+ # Points on contour
+ onContour = findall(x -> abs(x - targetInstability) < targetInstability * 0.3, D_f)
+ if isempty(onContour)
+ _, idx = findmin(abs.(D_f .- targetInstability))
+ onContour = [idx]
+ end
+
+ # EBIC tiebreaker
+ taskPredMats = [transpose(mtGrnData.taskGrnData[d].predictorMat[responsePredInds[d], :])
+ for d in 1:nTasks]
+ taskRespVecs = [vec(mtGrnData.taskGrnData[d].responseMat[res, :]) for d in 1:nTasks]
+ taskPenalties = [mtGrnData.taskGrnData[d].penaltyMat[res, responsePredInds[d]]
+ for d in 1:nTasks]
+
+ bestEBIC = Inf; best_ls = ls_f[1]; best_lf = lf_f[1]
+ for idx in onContour
+ r, c = Tuple(idx)
+ ls = ls_f[r]; lf = lf_f[c]
+ W = admm_fused_lasso(taskPredMats, taskRespVecs, taskPenalties,
+ mtGrnData.taskGraph, ls, lf;
+ elasticNetAlpha = elasticNetAlpha)
+ ebic = _ebic(W, taskPredMats, taskRespVecs, nTasks, ebicGamma)
+ if ebic < bestEBIC
+ bestEBIC = ebic; best_ls = ls; best_lf = lf
+ end
+ end
+ return best_ls, best_lf
+end
+
+
+# ── Shared helpers ───────────────────────────────────────────────────────────
+function _ebic(W::Matrix{Float64},
+ taskPredMats::Vector{Matrix{Float64}},
+ taskRespVecs::Vector{Vector{Float64}},
+ nTasks::Int,
+ gamma::Float64)
+ ebic = 0.0
+ for d in 1:nTasks
+ n_d = length(taskRespVecs[d])
+ p_d = size(taskPredMats[d], 2)
+ w_d = W[:, d]
+ k_d = sum(abs.(w_d) .> 1e-10)
+ yhat = taskPredMats[d] * w_d
+ rss = max(sum((taskRespVecs[d] .- yhat).^2), 1e-12)
+ # log binomial(p_d, k_d); lgamma comes from SpecialFunctions on Julia ≥ 1.0
+ logC = (k_d > 0 && p_d > k_d) ?
+ lgamma(p_d+1) - lgamma(k_d+1) - lgamma(p_d-k_d+1) : 0.0
+ ebic += n_d * log(rss / n_d) + k_d * log(n_d) + 2 * gamma * logC
+ end
+ return ebic / nTasks
+end
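For intuition, the per-task EBIC terms can be evaluated by hand on made-up numbers (gamma = 1; `log(binomial(p, k))` stands in for the `lgamma` combination used in `_ebic`):

```julia
# Single-task EBIC with synthetic values: residual fit term + model-size
# penalty + extended (model-space) penalty. All numbers are invented.
n, p, k = 50, 10, 3            # samples, candidate TFs, selected TFs
rss = 12.5                     # residual sum of squares
logC = log(binomial(p, k))     # == lgamma(p+1) - lgamma(k+1) - lgamma(p-k+1)
gamma = 1.0
ebic = n * log(rss / n) + k * log(n) + 2 * gamma * logC
println(round(ebic, digits = 2))
```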
+
+
+function _instabGrid(
+ mtGrnData, mtData, res,
+ ls_grid, lf_grid,
+ responsePredInds, nTasks, totSS, zTarget, elasticNetAlpha
+ )
+ nLS = length(ls_grid)
+ nLF = length(lf_grid)
+ instabs = zeros(Float64, nLS, nLF)
+ subsamps = mtGrnData.taskGrnData[1].subsamps
+
+ for (li, ls) in enumerate(ls_grid), (fi, lf) in enumerate(lf_grid)
+ ssVals = zeros(Float64, nTasks, length(responsePredInds[1]))
+ nss = min(totSS, size(subsamps, 1))
+
+ for ss in 1:nss
+ preds = Matrix{Float64}[]; resps = Vector{Float64}[]; pens = Vector{Float64}[]
+ for d in 1:nTasks
+ sub = mtGrnData.taskGrnData[d].subsamps[ss, :]
+ pidx = responsePredInds[d]
+ dt = fit(ZScoreTransform,
+ mtGrnData.taskGrnData[d].predictorMat[pidx, sub], dims=2)
+ cp = transpose(StatsBase.transform(dt,
+ mtGrnData.taskGrnData[d].predictorMat[pidx, sub]))
+ cr = zTarget ?
+ StatsBase.transform(fit(ZScoreTransform,
+ mtGrnData.taskGrnData[d].responseMat[res, sub], dims=1),
+ mtGrnData.taskGrnData[d].responseMat[res, sub]) :
+ mtGrnData.taskGrnData[d].responseMat[res, sub]
+ push!(preds, transpose(cp)); push!(resps, vec(cr))
+ push!(pens, mtGrnData.taskGrnData[d].penaltyMat[res, pidx])
+ end
+ W = admm_fused_lasso(preds, resps, pens, mtGrnData.taskGraph, ls, lf;
+ elasticNetAlpha = elasticNetAlpha)
+ for d in 1:nTasks
+ # count this task's selected edges only (column d of W), not a sum across tasks
+ ssVals[d, :] += abs.(sign.(W[:, d]))
+ end
+ end
+
+ theta2 = ssVals ./ nss
+ instabPerEdge = 2 .* theta2 .* (1 .- theta2)
+ instabs[li, fi] = mean(instabPerEdge)
+ end
+ return instabs
+end
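The per-edge quantity used in `_instabGrid` is the StARS/bStARS instability 2θ(1−θ), where θ is an edge's selection frequency across subsamples. A quick check of its shape:

```julia
# Instability is zero for edges selected never (theta = 0) or always
# (theta = 1), and maximal (0.5) at theta = 0.5.
theta2 = [0.0, 0.25, 0.5, 0.75, 1.0]
instab = 2 .* theta2 .* (1 .- theta2)
println(instab)   # [0.0, 0.375, 0.5, 0.375, 0.0]
```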
diff --git a/experimental/MTL/admm.jl b/experimental/MTL/admm.jl
new file mode 100755
index 0000000..83d820a
--- /dev/null
+++ b/experimental/MTL/admm.jl
@@ -0,0 +1,194 @@
+"""
+admm.jl [NEW FILE]
+
+ADMM solver for the graph-fused multitask LASSO.
+
+Objective (per target gene i):
+
+ min_{W} (1/2n) Σ_d ||X_i^(d) - A^(d)T w_i^(d)||²
+ + λ_s Σ_{k,d} |Φ_{k,d} w_{k,d}| ← prior-weighted LASSO
+ + λ_f Σ_{(d,d')∈E} sim(d,d') ||w^(d) - w^(d')||₁ ← fusion
+
+where W is a TFs × tasks matrix for gene i.
+
+ADMM reformulation:
+ Introduce Z^(d,d') = w^(d) - w^(d') for each edge (d,d') in E
+
+W-update : per-task weighted LASSO (uses GLMNet — unchanged from original)
+Z-update : soft thresholding on pairwise differences (fusion proximal op)
+U-update : dual variable update (standard ADMM)
+"""
+
+
+"""
+ softThreshold(x, threshold)
+
+Scalar soft-thresholding operator used in the Z-update.
+"""
+@inline function softThreshold(x::Float64, threshold::Float64)
+ return sign(x) * max(abs(x) - threshold, 0.0)
+end
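The operator's behavior on a few scalars (a standalone check; the one-line definition is repeated so the snippet runs on its own):

```julia
# Soft thresholding shrinks magnitudes by t and clips to zero below it,
# preserving sign — the proximal operator of the L1 norm.
softThreshold(x, t) = sign(x) * max(abs(x) - t, 0.0)

println(softThreshold(3.0, 1.0))    # 2.0  (shrunk toward zero)
println(softThreshold(-0.5, 1.0))   # -0.0 (below threshold: zeroed)
println(softThreshold(-2.5, 1.0))   # -1.5 (sign preserved)
```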
+
+
+"""
+ admm_fused_lasso(
+ predictorMats, responseMats, penaltyFactors,
+ taskSimilarity, lambda_s, lambda_f;
+ rho=1.0, maxIter=100, tol=1e-4, alpha=1.0
+ )
+
+Solve the graph-fused multitask LASSO for a single target gene
+using ADMM.
+
+# Arguments
+- `predictorMats` : Vector of predictor matrices, one per task (TFs × samples)
+- `responseMats` : Vector of response vectors, one per task (samples,)
+- `penaltyFactors` : Vector of penalty factor vectors, one per task (TFs,)
+- `taskSimilarity` : tasks × tasks similarity matrix
+- `lambda_s` : LASSO sparsity penalty
+- `lambda_f` : fusion penalty strength
+- `rho` : ADMM augmented Lagrangian parameter
+- `maxIter` : maximum ADMM iterations
+- `tol` : convergence tolerance
+- `alpha` : elastic net mixing (1.0 = pure LASSO)
+
+# Returns
+- `W` : TFs × tasks matrix of coefficients for this gene
+"""
+function admm_fused_lasso(
+ predictorMats::Vector{Matrix{Float64}},
+ responseMats::Vector{Vector{Float64}},
+ penaltyFactors::Vector{Vector{Float64}},
+ taskSimilarity::Matrix{Float64},
+ lambda_s::Float64,
+ lambda_f::Float64;
+ rho::Float64 = 1.0,
+ maxIter::Int = 100,
+ tol::Float64 = 1e-4,
+ alpha::Float64 = 1.0
+ )
+
+ nTasks = length(predictorMats)
+ nTFs = size(predictorMats[1], 1)
+
+ # Build edge list from similarity matrix (upper triangle, nonzero off-diagonal)
+ edges = [(d, dp) for d in 1:nTasks for dp in (d+1):nTasks if taskSimilarity[d, dp] > 0]
+ nEdges = length(edges)
+
+ # Initialize primal and dual variables
+ W = zeros(Float64, nTFs, nTasks) # TFs × tasks — the main variable
+ Z = zeros(Float64, nTFs, nEdges) # TFs × edges — fusion slack variables
+ U = zeros(Float64, nTFs, nEdges) # TFs × edges — dual variables
+
+ for iter in 1:maxIter
+ W_prev = copy(W)
+
+ # ── W-update: per-task weighted LASSO with augmented penalty ─────────
+        # For each task d, solve:
+        #     min_w  (1/(2n)) ||x^(d) - A^(d)ᵀ w^(d)||₂²
+        #          + λ_s ||Φ^(d) ⊙ w^(d)||₁
+        #          + (ρ/2) Σ_{e∋d} ||w^(d) - t_e||₂²
+        # where t_e is the proximity target for edge e: the neighbor task's
+        # coefficients adjusted by Z_e and U_e, signed by edge orientation.
+        #
+        # The quadratic augmentation term from ADMM is absorbed into
+        # an augmented response and augmented predictor passed to GLMNet.
+
+ for d in 1:nTasks
+            # collect edges involving task d, signed +1 when d is the edge's
+            # first endpoint and -1 when it is the second
+            dEdges = [(e, d1 == d ? 1.0 : -1.0)
+                      for (e, (d1, d2)) in enumerate(edges) if d1 == d || d2 == d]
+
+ nSamps = size(predictorMats[d], 2)
+ A = transpose(predictorMats[d]) # samples × TFs
+ x = responseMats[d]
+
+ if isempty(dEdges)
+ # no fusion neighbors — standard GLMNet call (identical to original)
+ lsoln = glmnet(A, x,
+ penalty_factor = penaltyFactors[d],
+ lambda = [lambda_s],
+ alpha = alpha)
+ W[:, d] = vec(lsoln.betas)
+ else
+ # augment response and predictors with ADMM proximity terms
+ augA = copy(A)
+ augX = copy(x)
+                sqrtRho = sqrt(rho)
+                for (e, sgn) in dEdges
+                    # The constraint for edge e = (d1, d2) is w^(d1) - w^(d2) = Z_e.
+                    # Completing the square in w^(d) gives the proximity target:
+                    #   d == d1:  w^(d2) + Z_e - U_e
+                    #   d == d2:  w^(d1) - Z_e + U_e
+                    (d1, d2) = edges[e]
+                    other = sgn > 0 ? d2 : d1
+                    target = W[:, other] + sgn * (Z[:, e] - U[:, e])
+                    # append sqrt(ρ) * I rows to A and sqrt(ρ) * target to x
+                    augA = vcat(augA, sqrtRho * I(nTFs))
+                    augX = vcat(augX, sqrtRho * target)
+                end
+ augPenalty = penaltyFactors[d] # penalty unchanged
+                # GLMNet scales its loss by 1/N, and the augmented system has
+                # nSamps + nTFs·|edges| rows, so rescale λ_s to keep the effective
+                # sparsity penalty comparable to the no-edge branch above
+                nAug = nSamps + nTFs * length(dEdges)
+                lsoln = glmnet(augA, augX,
+                               penalty_factor = augPenalty,
+                               lambda = [lambda_s * nSamps / nAug],
+                               alpha = alpha)
+ W[:, d] = vec(lsoln.betas)
+ end
+ end
+
+ # ── Z-update: fusion proximal operator (soft threshold) ──────────────
+ # Z_e = soft_threshold(w^(d) - w^(d') + U_e, λ_f * sim(d,d') / ρ)
+ for (e, (d, dp)) in enumerate(edges)
+ diff = W[:, d] - W[:, dp] + U[:, e]
+ threshold = lambda_f * taskSimilarity[d, dp] / rho
+ Z[:, e] = softThreshold.(diff, threshold)
+ end
+
+ # ── U-update: dual variable ───────────────────────────────────────────
+ for (e, (d, dp)) in enumerate(edges)
+ U[:, e] += W[:, d] - W[:, dp] - Z[:, e]
+ end
+
+ # ── Convergence check ─────────────────────────────────────────────────
+ primalResid = norm(W - W_prev)
+ if primalResid < tol
+ break
+ end
+ end
+
+ return W # TFs × tasks
+end
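+
+# Hedged usage sketch (shapes only — the variable names here are illustrative,
+# and GLMNet must be loaded for the W-update to run):
+#   A1, A2 = randn(10, 30), randn(10, 25)    # 10 TFs; per-task sample counts
+#   x1, x2 = randn(30), randn(25)
+#   S = [0.0 1.0; 1.0 0.0]                   # two tasks joined by one fusion edge
+#   W = admm_fused_lasso([A1, A2], [x1, x2],
+#                        [ones(10), ones(10)], S, 0.1, 0.05)
+#   size(W) == (10, 2)                       # TFs × tasks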
+
+
+"""
+ admmWarmStart(
+ predictorMats, responseMats, penaltyFactors,
+ taskSimilarity, lambdaRange_s, lambda_f;
+ kwargs...
+ )
+
+Run ADMM across a range of lambda_s values, as used when estimating
+per-task, per-gene instabilities. Analogous to bstarsWarmStart in the
+original codebase but operates on the fused multitask objective.
+
+Returns a 3D array of coefficients: TFs × tasks × lambdas.
+"""
+function admmWarmStart(
+ predictorMats::Vector{Matrix{Float64}},
+ responseMats::Vector{Vector{Float64}},
+ penaltyFactors::Vector{Vector{Float64}},
+ taskSimilarity::Matrix{Float64},
+ lambdaRange_s::Vector{Float64},
+ lambda_f::Float64;
+ kwargs...
+ )
+ nTFs = size(predictorMats[1], 1)
+ nTasks = length(predictorMats)
+ nLambda = length(lambdaRange_s)
+
+ betasByLambda = Array{Float64, 3}(undef, nTFs, nTasks, nLambda)
+
+ for (li, ls) in enumerate(lambdaRange_s)
+ W = admm_fused_lasso(
+ predictorMats, responseMats, penaltyFactors,
+ taskSimilarity, ls, lambda_f; kwargs...
+ )
+ betasByLambda[:, :, li] = W
+ end
+
+ return betasByLambda
+end
diff --git a/experimental/MTL/buildMultitask.jl b/experimental/MTL/buildMultitask.jl
new file mode 100755
index 0000000..0eea1f0
--- /dev/null
+++ b/experimental/MTL/buildMultitask.jl
@@ -0,0 +1,141 @@
+"""
+buildMultitask.jl [NEW FILE]
+
+Multitask versions of chooseLambda! and rankEdges! from buildGRN.jl.
+
+What stays the same vs original:
+─────────────────────────────────────────────────────────────────────
+UNCHANGED chooseLambda! — called once per task (no changes needed)
+UNCHANGED rankEdges! — called once per task (no changes needed)
+NEW chooseLambdaMT! — loops chooseLambda! over tasks
+NEW rankEdgesMT! — loops rankEdges! over tasks
+NEW buildConsensus! — averages task networks into consensus [NEW]
+"""
+
+
+"""
+ chooseLambdaMT!(mtGrnData, mtBuildGrn; instabilityLevel, targetInstability)
+
+Call chooseLambda! for each task independently.
+Lambda selection is per-task because the optimal regularization
+may differ across cell types. [UNCHANGED per-task logic]
+"""
+function chooseLambdaMT!(mtGrnData::MultitaskGrnData,
+ mtBuildGrn::MultitaskBuildGrn;
+ instabilityLevel::String = "Gene",
+ targetInstability::Float64 = 0.05)
+
+ for d in 1:length(mtBuildGrn.tasks)
+ chooseLambda!(mtGrnData.taskGrnData[d], mtBuildGrn.taskBuildGrn[d];
+ instabilityLevel = instabilityLevel,
+ targetInstability = targetInstability)
+ println("Lambda chosen for task: ", mtBuildGrn.tasks[d])
+ end
+end
+
+
+"""
+ rankEdgesMT!(mtData, tfaDataVec, mtGrnData, mtBuildGrn;
+ mergedTFsData, useMeanEdgesPerGeneMode, meanEdgesPerGene,
+ correlationWeight, outputDir)
+
+Call rankEdges! for each task independently, then build consensus.
+[UNCHANGED per-task logic, NEW consensus step]
+"""
+function rankEdgesMT!(mtData::MultitaskExpressionData,
+ tfaDataVec::Vector{PriorTFAData},
+ mtGrnData::MultitaskGrnData,
+ mtBuildGrn::MultitaskBuildGrn;
+ mergedTFsData::Union{mergedTFsResult, Nothing} = nothing,
+ useMeanEdgesPerGeneMode::Bool = true,
+ meanEdgesPerGene::Int = 20,
+ correlationWeight::Int = 1,
+ outputDir::Union{String, Nothing} = nothing)
+
+ for d in 1:length(mtData.tasks)
+ taskDir = outputDir !== nothing ? joinpath(outputDir, mtData.tasks[d]) : nothing
+ if taskDir !== nothing
+ mkpath(taskDir)
+ end
+
+ rankEdges!(mtData.taskData[d], tfaDataVec[d],
+ mtGrnData.taskGrnData[d], mtBuildGrn.taskBuildGrn[d];
+ mergedTFsData = mergedTFsData,
+ useMeanEdgesPerGeneMode = useMeanEdgesPerGeneMode,
+ meanEdgesPerGene = meanEdgesPerGene,
+ correlationWeight = correlationWeight,
+ outputDir = taskDir)
+
+ println("Edges ranked for task: ", mtData.tasks[d])
+ end
+
+ # build consensus network across tasks
+ buildConsensus!(mtBuildGrn, mtData.tasks)
+end
+
+
+"""
+ buildConsensus!(mtBuildGrn, tasks)
+
+Build the consensus network by averaging stability scores across tasks.
+Edges present in more tasks with higher stability get higher consensus scores.
+
+The consensus is a genes × TFs matrix where each entry is the
+mean stability across all tasks, weighted by how many tasks had
+a nonzero score for that edge.
+
+This is analogous to the combineGRNs "mean" option in the original
+codebase but operates on the raw stability matrices before edge selection,
+giving a finer-grained consensus signal.
+"""
+function buildConsensus!(mtBuildGrn::MultitaskBuildGrn, tasks::Vector{String})
+ nTasks = length(tasks)
+ if nTasks == 0
+ return
+ end
+
+ # get dimensions from first task
+ refNet = mtBuildGrn.taskBuildGrn[1].networkStability
+ nGenes, nTFs = size(refNet)
+
+ consensusMat = zeros(Float64, nGenes, nTFs)
+ countMat = zeros(Int, nGenes, nTFs)
+
+ for d in 1:nTasks
+ stab = mtBuildGrn.taskBuildGrn[d].networkStability
+ nonzero = stab .!= 0
+ consensusMat .+= stab
+ countMat .+= Int.(nonzero)
+ end
+
+ # mean over tasks that had nonzero edges (avoid diluting by absent edges)
+ countMat = max.(countMat, 1)
+ mtBuildGrn.consensusNetwork = consensusMat ./ countMat
+
+ println("Consensus network built across ", nTasks, " tasks.")
+end
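+
+# Worked example of the weighted mean (hand-checked): two tasks with stability
+# matrices [0.8 0.0; 0.4 0.6] and [0.6 0.0; 0.0 0.6] give sums
+# [1.4 0.0; 0.4 1.2] and nonzero counts [2 0; 1 2], so the consensus is
+# [0.7 0.0; 0.4 0.6] — an edge seen by only one task keeps its full score
+# rather than being halved, and absent edges stay 0 instead of producing NaN.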
+
+
+"""
+ writeNetworkTableMT!(mtBuildGrn; outputDir)
+
+Write per-task edge tables and the consensus table to outputDir.
+Per-task tables are written by writeNetworkTable! [UNCHANGED].
+Consensus table is written separately as consensus_edges.tsv.
+"""
+function writeNetworkTableMT!(mtBuildGrn::MultitaskBuildGrn;
+ outputDir::String)
+
+ mkpath(outputDir)
+
+ # per-task tables — unchanged writeNetworkTable! call
+ for (d, task) in enumerate(mtBuildGrn.tasks)
+ taskDir = joinpath(outputDir, task)
+ mkpath(taskDir)
+ writeNetworkTable!(mtBuildGrn.taskBuildGrn[d]; outputDir = taskDir)
+ end
+
+    # consensus table — raw genes × TFs consensus stability matrix
+    # (writedlm comes from DelimitedFiles, already used by the pipeline's readdlm calls)
+    consensusPath = joinpath(outputDir, "consensus_edges.tsv")
+    writedlm(consensusPath, mtBuildGrn.consensusNetwork)
+    println("Per-task networks written. Consensus network saved to: ", consensusPath)
+end
diff --git a/experimental/MTL/main 2.jl b/experimental/MTL/main 2.jl
new file mode 100755
index 0000000..ed07a05
--- /dev/null
+++ b/experimental/MTL/main 2.jl
@@ -0,0 +1,250 @@
+cd("/data/miraldiNB/Michael/Scripts/GRN/MultitaskInferelator")
+using Pkg
+Pkg.activate(".")
+using Revise
+include("src/MultitaskInferelator.jl")
+using .MultitaskInferelator
+
+"""
+ runMultitaskInferelator(; kwargs...)
+
+Multitask extension of the original Inferelator pipeline.
+
+Changes vs original runInferelator:
+────────────────────────────────────────────────────────────────────
+UNCHANGED loadExpressionData! load expression matrix
+UNCHANGED loadAndFilterTargetGenes! filter target genes
+UNCHANGED loadPotentialRegulators! load TF list
+UNCHANGED processTFAGenes! set TFA gene set
+UNCHANGED mergeDegenerateTFs merge degenerate TFs
+UNCHANGED processPriorFile! process prior file
+UNCHANGED preparePenaltyMatrix! build penalty matrix (per task)
+UNCHANGED constructSubsamples subsample indices (per task)
+UNCHANGED bstarsWarmStart coarse lambda bounds (per task)
+UNCHANGED chooseLambda! lambda selection (per task)
+UNCHANGED rankEdges! edge ranking (per task)
+UNCHANGED writeNetworkTable! write edge tables (per task)
+UNCHANGED combineGRNs combine TFA/TFmRNA networks
+UNCHANGED combineGRNS2 re-estimate TFA on combined
+
+NEW splitByTask! split expression by cell type
+NEW buildSimilarityFrom* construct task graph
+NEW calculateTFAPerTask! per-task TFA estimation
+NEW       bstartsEstimateInstabilityMT!  ADMM-fused instability estimation
+NEW buildConsensus! consensus network across tasks
+
+New parameters vs original:
+────────────────────────────────────────────────────────────────────
+taskLabelFile path to TSV mapping samples → task names
+fusionLambda λ_f fusion penalty strength (default 0.1)
+similaritySource :metadata, :expression, or :ontology
+similarityFile path to metadata/ontology file (if needed)
+"""
+function runMultitaskInferelator(;
+ geneExprFile::String,
+ targFile::String,
+ regFile::String,
+ priorFile::String,
+ priorFilePenalties::Vector{String},
+ taskLabelFile::String, # NEW — maps samples → tasks
+ tfaGeneFile::String = "",
+ outputDir::String,
+ tfaOptions::Vector{String} = ["", "TFmRNA"],
+ totSS::Int = 80,
+ bstarsTotSS::Int = 5,
+ subsampleFrac::Float64 = 0.68,
+ minLambda::Float64 = 0.01,
+ maxLambda::Float64 = 0.5,
+ totLambdasBstars::Int = 20,
+ totLambdas::Int = 40,
+ targetInstability::Float64 = 0.05,
+ meanEdgesPerGene::Int = 20,
+ correlationWeight::Int = 1,
+ minTargets::Int = 3,
+ edgeSS::Int = 0,
+ lambdaBias::Vector{Float64} = [0.5],
+ instabilityLevel::String = "Gene",
+ useMeanEdgesPerGeneMode::Bool = true,
+ combineOpt::String = "max",
+ zTarget::Bool = true,
+ fusionLambda::Float64 = 0.1, # NEW — fusion penalty (used by :fixed_ratio)
+ similaritySource::Symbol = :expression, # NEW — how to build task graph
+ similarityFile::Union{String, Nothing} = nothing, # NEW — metadata/ontology file
+ lambdaOpt::Symbol = :fixed_ratio, # NEW — :fixed_ratio | :ebic | :bstars_2d
+ fusionRatio::Float64 = 0.1, # NEW — for :fixed_ratio option
+ ebicGamma::Float64 = 1.0, # NEW — EBIC gamma (use 1.0 for p>>n)
+ gridSize::Int = 10, # NEW — lambda grid size for :ebic/:bstars_2d
+ refinementSize::Int = 10, # NEW — fine grid size for :bstars_2d
+ elasticNetAlpha::Float64 = 1.0 # NEW — L1/L2 mix (0.5 recommended with demerged TFs)
+)
+
+ # build output directory
+ subsamplePct = subsampleFrac * 100
+ subsampleStr = isinteger(subsamplePct) ? string(Int(subsamplePct)) : replace(string(subsamplePct), "." => "p")
+ lambdaStr = join(replace.(string.(lambdaBias), "." => "p"), "_")
+ networkBaseName = "MT_" * lowercase(instabilityLevel) * "Lambda" * lambdaStr * "_" *
+ string(totSS) * "totSS_" * string(meanEdgesPerGene) * "tfsPerGene_subsamplePCT" * subsampleStr
+ dirOut = joinpath(outputDir, networkBaseName)
+ mkpath(dirOut)
+
+ println("=== Multitask Inferelator Configuration ===")
+ println("Output Directory: ", dirOut)
+ println("Expression File: ", geneExprFile)
+ println("Task Label File: ", taskLabelFile)
+ println("Prior File: ", priorFile)
+ println("Fusion Lambda (λ_f): ", fusionLambda)
+ println("Similarity Source: ", similaritySource)
+ println("===========================================")
+
+ # ── STEP 1: Load expression data [UNCHANGED] ──────────────────────────────
+ data = GeneExpressionData()
+ loadExpressionData!(data, geneExprFile)
+ loadAndFilterTargetGenes!(data, targFile; epsilon=0.01)
+ loadPotentialRegulators!(data, regFile)
+ processTFAGenes!(data, tfaGeneFile; outputDir=dirOut)
+
+ # ── STEP 2: Split by task [NEW] ───────────────────────────────────────────
+ taskLabels = vec(readdlm(taskLabelFile, String)) # one label per sample column
+ mtData = MultitaskExpressionData()
+ splitByTask!(mtData, data, taskLabels)
+
+ # ── STEP 3: Build task similarity graph [NEW] ─────────────────────────────
+ if similaritySource == :metadata && similarityFile !== nothing
+ S = buildSimilarityFromMetadata(mtData.tasks, similarityFile)
+ elseif similaritySource == :ontology && similarityFile !== nothing
+ S = buildSimilarityFromOntology(mtData.tasks, similarityFile)
+ else
+ println("⚠ Using expression-based similarity — see taskSimilarity.jl for caveats.")
+ S = buildSimilarityFromExpression(mtData)
+ end
+ normalizeSimilarity!(S)
+ mtData.taskSimilarity = S
+
+ # ── STEP 4: Merge degenerate TFs [UNCHANGED] ──────────────────────────────
+ mergedTFsData = mergedTFsResult()
+ mergeDegenerateTFs(mergedTFsData, priorFile; fileFormat=2)
+
+ # ── STEP 5: Process prior file [UNCHANGED] ────────────────────────────────
+ # Prior is shared across tasks — one prior file, processed once
+ tfaDataTemplate = PriorTFAData()
+ processPriorFile!(tfaDataTemplate, data, priorFile; mergedTFsData, minTargets=minTargets)
+
+ # ── STEP 6: Per-task TFA estimation [NEW] ────────────────────────────────
+ # Each task gets its own TFA estimate
+ tfaDataVec = [deepcopy(tfaDataTemplate) for _ in mtData.tasks]
+ calculateTFAPerTask!(tfaDataVec, mtData;
+ edgeSS=edgeSS, zTarget=zTarget, outputDir=dirOut)
+
+ # ── STEP 7: Build GRN for each TFA option ────────────────────────────────
+ for tfaOpt in tfaOptions
+ optName = tfaOpt == "" ? "TFA" : "TFmRNA"
+ instabilitiesDir = joinpath(dirOut, optName)
+ mkpath(instabilitiesDir)
+
+ # initialize per-task GrnData
+ mtGrnData = MultitaskGrnData()
+ mtGrnData.fusionLambda = fusionLambda
+ mtGrnData.taskGraph = S
+ for _ in mtData.tasks
+ push!(mtGrnData.taskGrnData, GrnData())
+ end
+
+ # prepare matrices — UNCHANGED per-task logic
+ preparePredictorMatMT!(mtGrnData, mtData, tfaDataVec, tfaOpt)
+ preparePenaltyMatrixMT!(mtData, mtGrnData, priorFilePenalties, lambdaBias, tfaOpt)
+ constructSubsamplesMT!(mtData, mtGrnData; totSS=bstarsTotSS, subsampleFrac=subsampleFrac)
+
+ # coarse warm start — run per task independently [UNCHANGED]
+ for d in 1:length(mtData.tasks)
+ bstarsWarmStart(mtData.taskData[d], tfaDataVec[d], mtGrnData.taskGrnData[d];
+ minLambda=minLambda, maxLambda=maxLambda,
+ totLambdasBstars=totLambdasBstars,
+ targetInstability=targetInstability, zTarget=zTarget)
+ end
+
+ constructSubsamplesMT!(mtData, mtGrnData; totSS=totSS, subsampleFrac=subsampleFrac)
+
+ # instability estimation — ADMM-fused [NEW CORE STEP]
+ bstartsEstimateInstabilityMT!(mtGrnData, mtData;
+ totLambdas = totLambdas,
+ instabilityLevel = instabilityLevel,
+ zTarget = zTarget,
+ outputDir = instabilitiesDir,
+ lambdaOpt = lambdaOpt, # NEW
+ fusionRatio = fusionRatio, # NEW
+ ebicGamma = ebicGamma, # NEW
+ gridSize = gridSize, # NEW
+ refinementSize = refinementSize, # NEW
+ elasticNetAlpha = elasticNetAlpha) # NEW
+
+ # lambda selection and edge ranking — UNCHANGED per-task logic
+ mtBuildGrn = MultitaskBuildGrn()
+ mtBuildGrn.tasks = mtData.tasks
+ for _ in mtData.tasks
+ push!(mtBuildGrn.taskBuildGrn, BuildGrn())
+ end
+
+ chooseLambdaMT!(mtGrnData, mtBuildGrn;
+ instabilityLevel=instabilityLevel,
+ targetInstability=targetInstability)
+
+ rankEdgesMT!(mtData, tfaDataVec, mtGrnData, mtBuildGrn;
+ mergedTFsData=mergedTFsData,
+ useMeanEdgesPerGeneMode=useMeanEdgesPerGeneMode,
+ meanEdgesPerGene=meanEdgesPerGene,
+ correlationWeight=correlationWeight,
+ outputDir=instabilitiesDir)
+
+ writeNetworkTableMT!(mtBuildGrn; outputDir=instabilitiesDir)
+ end
+
+ # ── STEP 8: Combine TFA and TFmRNA networks [UNCHANGED] ──────────────────
+    # For each task, combine its TFA and TFmRNA networks into one GRN
+ for task in mtData.tasks
+ combinedNetDir = joinpath(dirOut, "Combined", task)
+ mkpath(combinedNetDir)
+ nets2combine = [
+ joinpath(dirOut, "TFA", task, "edges.tsv"),
+ joinpath(dirOut, "TFmRNA", task, "edges.tsv")
+ ]
+ combineGRNs(nets2combine;
+ combineOpt=combineOpt,
+ meanEdgesPerGene=meanEdgesPerGene,
+ useMeanEdgesPerGeneMode=useMeanEdgesPerGeneMode,
+ saveDir=combinedNetDir,
+ saveName=task)
+
+ # re-estimate TFA on combined network [UNCHANGED]
+ netsCombinedSparse = joinpath(combinedNetDir, "combined_$(task)_$(combineOpt)_sp.tsv")
+ taskIdx = findfirst(x -> x == task, mtData.tasks)
+ combineGRNS2(mtData.taskData[taskIdx], mergedTFsData, tfaGeneFile,
+ netsCombinedSparse, edgeSS, minTargets,
+ geneExprFile, targFile, regFile; outputDir=combinedNetDir)
+ end
+
+ println("=== Multitask Inferelator Complete ===")
+end
+
+
+# ── Run ────────────────────────────────────────────────────────────────────────
+runMultitaskInferelator(
+ geneExprFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/pseudobulk/counts_Tfh10_vst.txt",
+ targFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/target_genes/gene_targ_Tfh10.txt",
+ regFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/pot_regs/TF_Tfh10_final.txt",
+ priorFile = "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv",
+ priorFilePenalties = ["/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv"],
+ taskLabelFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/taskLabels.txt", # NEW
+ outputDir = "/data/miraldiNB/Michael/projects/GRN/mCD4T_Wayman/MultitaskInferelator/test",
+ fusionLambda = 0.1, # fusion penalty (used by :fixed_ratio)
+ similaritySource = :expression, # or :metadata / :ontology with similarityFile
+ # ── Lambda optimization (pick one) ───────────────────────────────────────
+ lambdaOpt = :fixed_ratio, # cheapest — good for first runs
+ # lambdaOpt = :ebic, # principled — recommended for production
+ # lambdaOpt = :bstars_2d, # strongest — use for benchmarking
+ fusionRatio = 0.1, # for :fixed_ratio — lambda_f = fusionRatio * lambda_s
+ ebicGamma = 1.0, # EBIC gamma — 1.0 recommended when p >> n
+ gridSize = 10, # lambda grid size for :ebic and :bstars_2d
+ refinementSize = 10, # fine grid size for :bstars_2d stage 2
+ # ── Elastic net ──────────────────────────────────────────────────────────
+ elasticNetAlpha = 1.0 # 1.0=pure LASSO | 0.5=elastic net (recommended with demerged TFs)
+)
diff --git a/experimental/MTL/main.jl b/experimental/MTL/main.jl
new file mode 100755
index 0000000..f883862
--- /dev/null
+++ b/experimental/MTL/main.jl
@@ -0,0 +1,228 @@
+cd("/data/miraldiNB/Michael/Scripts/GRN/MultitaskInferelator")
+using Pkg
+Pkg.activate(".")
+using Revise
+include("src/MultitaskInferelator.jl")
+using .MultitaskInferelator
+
+"""
+ runMultitaskInferelator(; kwargs...)
+
+Multitask extension of the original Inferelator pipeline.
+
+Changes vs original runInferelator:
+────────────────────────────────────────────────────────────────────
+UNCHANGED loadExpressionData! load expression matrix
+UNCHANGED loadAndFilterTargetGenes! filter target genes
+UNCHANGED loadPotentialRegulators! load TF list
+UNCHANGED processTFAGenes! set TFA gene set
+UNCHANGED mergeDegenerateTFs merge degenerate TFs
+UNCHANGED processPriorFile! process prior file
+UNCHANGED preparePenaltyMatrix! build penalty matrix (per task)
+UNCHANGED constructSubsamples subsample indices (per task)
+UNCHANGED bstarsWarmStart coarse lambda bounds (per task)
+UNCHANGED chooseLambda! lambda selection (per task)
+UNCHANGED rankEdges! edge ranking (per task)
+UNCHANGED writeNetworkTable! write edge tables (per task)
+UNCHANGED combineGRNs combine TFA/TFmRNA networks
+UNCHANGED combineGRNS2 re-estimate TFA on combined
+
+NEW splitByTask! split expression by cell type
+NEW buildSimilarityFrom* construct task graph
+NEW calculateTFAPerTask! per-task TFA estimation
+NEW       bstartsEstimateInstabilityMT!  ADMM-fused instability estimation
+NEW buildConsensus! consensus network across tasks
+
+New parameters vs original:
+────────────────────────────────────────────────────────────────────
+taskLabelFile path to TSV mapping samples → task names
+fusionLambda λ_f fusion penalty strength (default 0.1)
+similaritySource :metadata, :expression, or :ontology
+similarityFile path to metadata/ontology file (if needed)
+"""
+function runMultitaskInferelator(;
+ geneExprFile::String,
+ targFile::String,
+ regFile::String,
+ priorFile::String,
+ priorFilePenalties::Vector{String},
+ taskLabelFile::String, # NEW — maps samples → tasks
+ tfaGeneFile::String = "",
+ outputDir::String,
+ tfaOptions::Vector{String} = ["", "TFmRNA"],
+ totSS::Int = 80,
+ bstarsTotSS::Int = 5,
+ subsampleFrac::Float64 = 0.68,
+ minLambda::Float64 = 0.01,
+ maxLambda::Float64 = 0.5,
+ totLambdasBstars::Int = 20,
+ totLambdas::Int = 40,
+ targetInstability::Float64 = 0.05,
+ meanEdgesPerGene::Int = 20,
+ correlationWeight::Int = 1,
+ minTargets::Int = 3,
+ edgeSS::Int = 0,
+ lambdaBias::Vector{Float64} = [0.5],
+ instabilityLevel::String = "Gene",
+ useMeanEdgesPerGeneMode::Bool = true,
+ combineOpt::String = "max",
+ zTarget::Bool = true,
+ fusionLambda::Float64 = 0.1, # NEW — fusion penalty
+ similaritySource::Symbol = :expression, # NEW — how to build task graph
+ similarityFile::Union{String, Nothing} = nothing # NEW — metadata/ontology file
+)
+
+ # build output directory
+ subsamplePct = subsampleFrac * 100
+ subsampleStr = isinteger(subsamplePct) ? string(Int(subsamplePct)) : replace(string(subsamplePct), "." => "p")
+ lambdaStr = join(replace.(string.(lambdaBias), "." => "p"), "_")
+ networkBaseName = "MT_" * lowercase(instabilityLevel) * "Lambda" * lambdaStr * "_" *
+ string(totSS) * "totSS_" * string(meanEdgesPerGene) * "tfsPerGene_subsamplePCT" * subsampleStr
+ dirOut = joinpath(outputDir, networkBaseName)
+ mkpath(dirOut)
+
+ println("=== Multitask Inferelator Configuration ===")
+ println("Output Directory: ", dirOut)
+ println("Expression File: ", geneExprFile)
+ println("Task Label File: ", taskLabelFile)
+ println("Prior File: ", priorFile)
+ println("Fusion Lambda (λ_f): ", fusionLambda)
+ println("Similarity Source: ", similaritySource)
+ println("===========================================")
+
+ # ── STEP 1: Load expression data [UNCHANGED] ──────────────────────────────
+ data = GeneExpressionData()
+ loadExpressionData!(data, geneExprFile)
+ loadAndFilterTargetGenes!(data, targFile; epsilon=0.01)
+ loadPotentialRegulators!(data, regFile)
+ processTFAGenes!(data, tfaGeneFile; outputDir=dirOut)
+
+ # ── STEP 2: Split by task [NEW] ───────────────────────────────────────────
+ taskLabels = vec(readdlm(taskLabelFile, String)) # one label per sample column
+ mtData = MultitaskExpressionData()
+ splitByTask!(mtData, data, taskLabels)
+
+ # ── STEP 3: Build task similarity graph [NEW] ─────────────────────────────
+ if similaritySource == :metadata && similarityFile !== nothing
+ S = buildSimilarityFromMetadata(mtData.tasks, similarityFile)
+ elseif similaritySource == :ontology && similarityFile !== nothing
+ S = buildSimilarityFromOntology(mtData.tasks, similarityFile)
+ else
+ println("⚠ Using expression-based similarity — see taskSimilarity.jl for caveats.")
+ S = buildSimilarityFromExpression(mtData)
+ end
+ normalizeSimilarity!(S)
+ mtData.taskSimilarity = S
+
+ # ── STEP 4: Merge degenerate TFs [UNCHANGED] ──────────────────────────────
+ mergedTFsData = mergedTFsResult()
+ mergeDegenerateTFs(mergedTFsData, priorFile; fileFormat=2)
+
+ # ── STEP 5: Process prior file [UNCHANGED] ────────────────────────────────
+ # Prior is shared across tasks — one prior file, processed once
+ tfaDataTemplate = PriorTFAData()
+ processPriorFile!(tfaDataTemplate, data, priorFile; mergedTFsData, minTargets=minTargets)
+
+ # ── STEP 6: Per-task TFA estimation [NEW] ────────────────────────────────
+ # Each task gets its own TFA estimate
+ tfaDataVec = [deepcopy(tfaDataTemplate) for _ in mtData.tasks]
+ calculateTFAPerTask!(tfaDataVec, mtData;
+ edgeSS=edgeSS, zTarget=zTarget, outputDir=dirOut)
+
+ # ── STEP 7: Build GRN for each TFA option ────────────────────────────────
+ for tfaOpt in tfaOptions
+ optName = tfaOpt == "" ? "TFA" : "TFmRNA"
+ instabilitiesDir = joinpath(dirOut, optName)
+ mkpath(instabilitiesDir)
+
+ # initialize per-task GrnData
+ mtGrnData = MultitaskGrnData()
+ mtGrnData.fusionLambda = fusionLambda
+ mtGrnData.taskGraph = S
+ for _ in mtData.tasks
+ push!(mtGrnData.taskGrnData, GrnData())
+ end
+
+ # prepare matrices — UNCHANGED per-task logic
+ preparePredictorMatMT!(mtGrnData, mtData, tfaDataVec, tfaOpt)
+ preparePenaltyMatrixMT!(mtData, mtGrnData, priorFilePenalties, lambdaBias, tfaOpt)
+ constructSubsamplesMT!(mtData, mtGrnData; totSS=bstarsTotSS, subsampleFrac=subsampleFrac)
+
+ # coarse warm start — run per task independently [UNCHANGED]
+ for d in 1:length(mtData.tasks)
+ bstarsWarmStart(mtData.taskData[d], tfaDataVec[d], mtGrnData.taskGrnData[d];
+ minLambda=minLambda, maxLambda=maxLambda,
+ totLambdasBstars=totLambdasBstars,
+ targetInstability=targetInstability, zTarget=zTarget)
+ end
+
+ constructSubsamplesMT!(mtData, mtGrnData; totSS=totSS, subsampleFrac=subsampleFrac)
+
+ # instability estimation — ADMM-fused [NEW CORE STEP]
+ bstartsEstimateInstabilityMT!(mtGrnData, mtData;
+ totLambdas=totLambdas,
+ instabilityLevel=instabilityLevel,
+ zTarget=zTarget,
+ outputDir=instabilitiesDir)
+
+ # lambda selection and edge ranking — UNCHANGED per-task logic
+ mtBuildGrn = MultitaskBuildGrn()
+ mtBuildGrn.tasks = mtData.tasks
+ for _ in mtData.tasks
+ push!(mtBuildGrn.taskBuildGrn, BuildGrn())
+ end
+
+ chooseLambdaMT!(mtGrnData, mtBuildGrn;
+ instabilityLevel=instabilityLevel,
+ targetInstability=targetInstability)
+
+ rankEdgesMT!(mtData, tfaDataVec, mtGrnData, mtBuildGrn;
+ mergedTFsData=mergedTFsData,
+ useMeanEdgesPerGeneMode=useMeanEdgesPerGeneMode,
+ meanEdgesPerGene=meanEdgesPerGene,
+ correlationWeight=correlationWeight,
+ outputDir=instabilitiesDir)
+
+ writeNetworkTableMT!(mtBuildGrn; outputDir=instabilitiesDir)
+ end
+
+ # ── STEP 8: Combine TFA and TFmRNA networks [UNCHANGED] ──────────────────
+    # For each task, combine its TFA and TFmRNA networks into one GRN
+ for task in mtData.tasks
+ combinedNetDir = joinpath(dirOut, "Combined", task)
+ mkpath(combinedNetDir)
+ nets2combine = [
+ joinpath(dirOut, "TFA", task, "edges.tsv"),
+ joinpath(dirOut, "TFmRNA", task, "edges.tsv")
+ ]
+ combineGRNs(nets2combine;
+ combineOpt=combineOpt,
+ meanEdgesPerGene=meanEdgesPerGene,
+ useMeanEdgesPerGeneMode=useMeanEdgesPerGeneMode,
+ saveDir=combinedNetDir,
+ saveName=task)
+
+ # re-estimate TFA on combined network [UNCHANGED]
+ netsCombinedSparse = joinpath(combinedNetDir, "combined_$(task)_$(combineOpt)_sp.tsv")
+ taskIdx = findfirst(x -> x == task, mtData.tasks)
+ combineGRNS2(mtData.taskData[taskIdx], mergedTFsData, tfaGeneFile,
+ netsCombinedSparse, edgeSS, minTargets,
+ geneExprFile, targFile, regFile; outputDir=combinedNetDir)
+ end
+
+ println("=== Multitask Inferelator Complete ===")
+end
+
+
+# ── Run ────────────────────────────────────────────────────────────────────────
+runMultitaskInferelator(
+ geneExprFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/pseudobulk/counts_Tfh10_vst.txt",
+ targFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/target_genes/gene_targ_Tfh10.txt",
+ regFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/pot_regs/TF_Tfh10_final.txt",
+ priorFile = "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv",
+ priorFilePenalties = ["/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv"],
+ taskLabelFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/taskLabels.txt", # NEW
+ outputDir = "/data/miraldiNB/Michael/projects/GRN/mCD4T_Wayman/MultitaskInferelator/test",
+ fusionLambda = 0.1, # NEW — tune this
+ similaritySource = :expression # NEW — or :metadata with similarityFile
+)
diff --git a/experimental/MTL/multitaskData.jl b/experimental/MTL/multitaskData.jl
new file mode 100755
index 0000000..079a329
--- /dev/null
+++ b/experimental/MTL/multitaskData.jl
@@ -0,0 +1,96 @@
+"""
+multitaskData.jl [NEW FILE]
+
+Defines the data structures needed for multitask inference.
+The key addition over the original codebase is splitting one expression
+matrix into per-task views, and storing per-task GRN results.
+"""
+
+# ── Task-split expression data ────────────────────────────────────────────────
+"""
+ MultitaskExpressionData
+
+Holds per-task views of the expression data. Each task corresponds to
+a biological context (e.g. cell type) defined by the column labels of
+the original expression matrix.
+
+Fields
+──────
+- `tasks` : names of each task (e.g. ["Tfh", "Th1", "Treg"])
+- `taskData` : one GeneExpressionData per task [per-task]
+- `taskSimilarity` : tasks × tasks similarity matrix [NEW]
+- `taskLabels` : maps each sample column → task name [NEW]
+"""
+mutable struct MultitaskExpressionData
+ tasks::Vector{String}
+ taskData::Vector{GeneExpressionData}
+ taskSimilarity::Matrix{Float64}
+ taskLabels::Vector{String}
+
+ function MultitaskExpressionData()
+ return new(
+ String[],
+ GeneExpressionData[],
+ Matrix{Float64}(undef, 0, 0),
+ String[]
+ )
+ end
+end
+
+
+# ── Per-task GRN data ─────────────────────────────────────────────────────────
+"""
+ MultitaskGrnData
+
+Wraps one GrnData per task plus the fusion penalty parameter.
+The per-task GrnData structs are identical to the original codebase —
+the only addition is `fusionLambda` which controls how strongly
+related tasks are pulled toward each other.
+
+Fields
+──────
+- `taskGrnData` : one GrnData per task [per-task]
+- `fusionLambda` : λ_f — fusion penalty strength [NEW]
+- `taskGraph` : tasks × tasks similarity (copied from MT data for convenience)
+"""
+mutable struct MultitaskGrnData
+ taskGrnData::Vector{GrnData}
+ fusionLambda::Float64
+ taskGraph::Matrix{Float64}
+
+ function MultitaskGrnData()
+ return new(
+ GrnData[],
+ 0.1,
+ Matrix{Float64}(undef, 0, 0)
+ )
+ end
+end
+
+
+# ── Per-task BuildGrn ─────────────────────────────────────────────────────────
+"""
+ MultitaskBuildGrn
+
+Holds one BuildGrn per task plus the consensus network
+(averaged across tasks after fusion).
+
+Fields
+──────
+- `taskBuildGrn` : one BuildGrn per task [per-task]
+- `consensusNetwork` : gene × TF matrix averaged across tasks [NEW]
+- `tasks` : task names for indexing [NEW]
+"""
+mutable struct MultitaskBuildGrn
+ taskBuildGrn::Vector{BuildGrn}
+ consensusNetwork::Matrix{Float64}
+ tasks::Vector{String}
+
+ function MultitaskBuildGrn()
+ return new(
+ BuildGrn[],
+ Matrix{Float64}(undef, 0, 0),
+ String[]
+ )
+ end
+end
diff --git a/experimental/MTL/prepareMultitask.jl b/experimental/MTL/prepareMultitask.jl
new file mode 100755
index 0000000..0b71e6f
--- /dev/null
+++ b/experimental/MTL/prepareMultitask.jl
@@ -0,0 +1,272 @@
+"""
+prepareMultitask.jl [NEW FILE]
+
+Multitask versions of the preparation functions from prepareGRN.jl.
+
+What stays the same vs original:
+─────────────────────────────────────────────────────────────────────
+UNCHANGED preparePredictorMat! — called once per task
+UNCHANGED preparePenaltyMatrix! — called once per task
+UNCHANGED constructSubsamples — called once per task
+UNCHANGED bstarsWarmStart — called once per task (coarse pass)
+NEW splitByTask! — splits expression matrix by task
+NEW preparePredictorMatMT! — loops preparePredictorMat! over tasks
+NEW preparePenaltyMatrixMT! — loops preparePenaltyMatrix! over tasks
+NEW constructSubsamplesMT! — loops constructSubsamples over tasks
+NEW bstartsEstimateInstabilityMT — replaces bstartsEstimateInstability,
+ uses ADMM instead of GLMNet per gene
+"""
+
+
+# ── NEW: Split expression matrix by task ─────────────────────────────────────
+"""
+ splitByTask!(mtData::MultitaskExpressionData, data::GeneExpressionData, taskLabels::Vector{String})
+
+Split one GeneExpressionData into per-task views based on column labels.
+
+`taskLabels` must have the same length as `data.cellLabels` and maps
+each sample column to a task name (e.g. "Tfh", "Th1", "Treg").
+
+This is the only entry point that creates per-task data — everything
+downstream just iterates over `mtData.taskData`.
+"""
+function splitByTask!(mtData::MultitaskExpressionData,
+ data::GeneExpressionData,
+ taskLabels::Vector{String})
+
+ if length(taskLabels) != length(data.cellLabels)
+ error("taskLabels length ($(length(taskLabels))) must match number of samples ($(length(data.cellLabels)))")
+ end
+
+ uniqueTasks = unique(taskLabels)
+ mtData.taskLabels = taskLabels
+ mtData.tasks = uniqueTasks
+
+ for task in uniqueTasks
+ taskInds = findall(x -> x == task, taskLabels)
+
+ td = GeneExpressionData()
+ # shared metadata — same across all tasks
+ td.geneNames = data.geneNames
+ td.targGenes = data.targGenes
+ td.potRegs = data.potRegs
+ td.potRegsmRNA = data.potRegsmRNA
+ td.tfaGenes = data.tfaGenes
+ # per-task column subsets
+ td.cellLabels = data.cellLabels[taskInds]
+ td.geneExpressionMat = data.geneExpressionMat[:, taskInds]
+ td.targGeneMat = data.targGeneMat[:, taskInds]
+ td.potRegMatmRNA = data.potRegMatmRNA[:, taskInds]
+ td.tfaGeneMat = data.tfaGeneMat[:, taskInds]
+
+ push!(mtData.taskData, td)
+ end
+
+ println("Split into $(length(uniqueTasks)) tasks: ", join(uniqueTasks, ", "))
+end
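The column split performed by `splitByTask!` reduces to the same `findall`-based indexing on plain arrays; this toy version shows the core operation without the `GeneExpressionData` fields.

```julia
# Toy version of the per-task column split: genes × samples matrix plus a
# task label per sample column.
expr = [1.0 2.0 3.0 4.0;
        5.0 6.0 7.0 8.0]
taskLabels = ["Tfh", "Th1", "Tfh", "Th1"]

byTask = Dict(task => expr[:, findall(==(task), taskLabels)]
              for task in unique(taskLabels))

size(byTask["Tfh"])   # (2, 2) — columns 1 and 3
```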
+
+
+# ── NEW: Per-task TFA estimation ──────────────────────────────────────────────
+"""
+ calculateTFAPerTask!(tfaDataVec, mtData; edgeSS=0, zTarget=false, outputDir=nothing)
+
+Run calculateTFA! independently for each task, populating each entry of
+`tfaDataVec` in place (one PriorTFAData per task).
+
+Each task gets its own TFA estimate because TF activity can differ
+substantially across cell types — pooling them would obscure this.
+"""
+function calculateTFAPerTask!(tfaDataVec::Vector{PriorTFAData},
+ mtData::MultitaskExpressionData;
+ edgeSS::Int = 0,
+ zTarget::Bool = false,
+ outputDir::Union{String, Nothing} = nothing)
+
+ for (d, taskName) in enumerate(mtData.tasks)
+ taskDir = outputDir !== nothing ? joinpath(outputDir, taskName) : nothing
+ if taskDir !== nothing
+ mkpath(taskDir)
+ end
+        calculateTFA!(tfaDataVec[d], mtData.taskData[d];
+                      edgeSS = edgeSS, zTarget = zTarget, outputDir = taskDir)
+ println("TFA estimated for task: ", taskName)
+ end
+end
+
+
+# ── NEW: Loop prepare functions over tasks ────────────────────────────────────
+"""
+ preparePredictorMatMT!(mtGrnData, mtData, tfaDataVec, tfaOpt)
+
+Call preparePredictorMat! for each task. [UNCHANGED per-task logic]
+"""
+function preparePredictorMatMT!(mtGrnData::MultitaskGrnData,
+ mtData::MultitaskExpressionData,
+ tfaDataVec::Vector{PriorTFAData},
+ tfaOpt::String)
+ for d in 1:length(mtData.tasks)
+ preparePredictorMat!(mtGrnData.taskGrnData[d], mtData.taskData[d], tfaDataVec[d], tfaOpt)
+ end
+end
+
+
+"""
+ preparePenaltyMatrixMT!(mtData, mtGrnData, priorFilePenalties, lambdaBias, tfaOpt)
+
+Call preparePenaltyMatrix! for each task. [UNCHANGED per-task logic]
+"""
+function preparePenaltyMatrixMT!(mtData::MultitaskExpressionData,
+ mtGrnData::MultitaskGrnData,
+ priorFilePenalties::Vector{String},
+ lambdaBias::Vector{Float64},
+ tfaOpt::String)
+ for d in 1:length(mtData.tasks)
+ preparePenaltyMatrix!(mtData.taskData[d], mtGrnData.taskGrnData[d],
+ priorFilePenalties, lambdaBias, tfaOpt)
+ end
+end
+
+
+"""
+ constructSubsamplesMT!(mtData, mtGrnData; totSS, subsampleFrac)
+
+Call constructSubsamples for each task. [UNCHANGED per-task logic]
+"""
+function constructSubsamplesMT!(mtData::MultitaskExpressionData,
+ mtGrnData::MultitaskGrnData;
+ totSS::Int = 50,
+ subsampleFrac::Float64 = 0.68)
+ for d in 1:length(mtData.tasks)
+ constructSubsamples(mtData.taskData[d], mtGrnData.taskGrnData[d];
+ totSS = totSS, subsampleFrac = subsampleFrac)
+ end
+end
+
+
+# ── NEW: Multitask instability estimation (replaces bstartsEstimateInstability) ──
+"""
+ bstartsEstimateInstabilityMT!(mtGrnData, mtData;
+ totLambdas, instabilityLevel, zTarget, outputDir)
+
+Replaces bstartsEstimateInstability from the original codebase.
+
+Key difference: instead of fitting GLMNet independently per task,
+this calls admm_fused_lasso which couples related tasks via λ_f.
+
+The per-gene parallelism (Threads.@threads) is preserved — tasks are
+coupled WITHIN each gene's ADMM problem, not across genes.
+
+Instability is computed the same way as the original:
+ θ = (1/totSS) * edge_selection_count
+ instab = 2 * θ * (1 - θ)
+But now per task, giving a tasks × lambdas instability surface per gene.
+"""
+function bstartsEstimateInstabilityMT!(mtGrnData::MultitaskGrnData,
+ mtData::MultitaskExpressionData;
+ totLambdas::Int = 10,
+ instabilityLevel::String = "Gene",
+ zTarget::Bool = false,
+ outputDir::Union{String, Nothing} = nothing)
+
+ nTasks = length(mtData.tasks)
+ # use first task to get dimensions — all tasks share same gene/TF sets
+ refGrn = mtGrnData.taskGrnData[1]
+ totResponses, totSamps = size(refGrn.responseMat)
+ totPreds = size(refGrn.predictorMat, 1)
+ totSS = size(refGrn.subsamps, 1)
+
+    # Lambda range: span the min/max network-level bounds across all tasks
+ # TODO: consider per-task lambda ranges for heterogeneous tasks
+ minLambda = minimum([g.minLambdaNet for g in mtGrnData.taskGrnData])
+ maxLambda = maximum([g.maxLambdaNet for g in mtGrnData.taskGrnData])
+ lambdaRange = reverse(collect(range(minLambda, stop=maxLambda, length=totLambdas)))
+
+    # Store edge selection counts: lambdas × genes × TFs × tasks.
+    # Both arrays are zero-initialized: `.+=` below accumulates counts, and
+    # entries for penalty-excluded predictors must read as 0, not Inf/undef.
+    ssMatrix = zeros(totLambdas, totResponses, totPreds, nTasks)
+    betas = zeros(totResponses, totPreds, totLambdas, nTasks)
+
+ # get finite predictor indices per task per response
+ responsePredInds = [[findall(x -> x != Inf, mtGrnData.taskGrnData[d].penaltyMat[res, :])
+ for res in 1:totResponses]
+ for d in 1:nTasks]
+
+ Threads.@threads for res in ProgressBar(1:totResponses)
+ for ss in 1:totSS
+ # collect per-task predictors and responses for this subsample
+ taskPredMats = Matrix{Float64}[]
+ taskRespVecs = Vector{Float64}[]
+ taskPenalties = Vector{Float64}[]
+
+ for d in 1:nTasks
+ subsamp = mtGrnData.taskGrnData[d].subsamps[ss, :]
+ predInds = responsePredInds[d][res]
+
+ dt = fit(ZScoreTransform,
+ mtGrnData.taskGrnData[d].predictorMat[predInds, subsamp], dims=2)
+ currPreds = transpose(StatsBase.transform(dt,
+ mtGrnData.taskGrnData[d].predictorMat[predInds, subsamp]))
+
+ if zTarget
+ dt2 = fit(ZScoreTransform,
+ mtGrnData.taskGrnData[d].responseMat[res, subsamp], dims=1)
+ currResp = StatsBase.transform(dt2,
+ mtGrnData.taskGrnData[d].responseMat[res, subsamp])
+ else
+ currResp = mtGrnData.taskGrnData[d].responseMat[res, subsamp]
+ end
+
+ push!(taskPredMats, transpose(currPreds)) # TFs × subsampled_cells
+ push!(taskRespVecs, vec(currResp))
+ push!(taskPenalties, mtGrnData.taskGrnData[d].penaltyMat[res, predInds])
+ end
+
+ # ── ADMM: couples tasks via fusion penalty ─────────────────────────
+ # This is the key difference from the original codebase.
+ # Original: glmnet(currPreds, currResponses, ...) per task independently
+            # New:      admmWarmStart (ADMM fused lasso) jointly over all tasks
+ betasByLambda = admmWarmStart(
+ taskPredMats, taskRespVecs, taskPenalties,
+ mtGrnData.taskGraph, lambdaRange, mtGrnData.fusionLambda
+ ) # TFs × tasks × lambdas
+
+ for d in 1:nTasks
+ predInds = responsePredInds[d][res]
+ for (li, _) in enumerate(lambdaRange)
+ ssMatrix[li, res, predInds, d] .+= abs.(sign.(betasByLambda[:, d, li]))
+ betas[res, predInds, li, d] = betasByLambda[:, d, li]
+ end
+ end
+ end
+ end
+
+ # compute instabilities per task (same formula as original)
+ for d in 1:nTasks
+ grnD = mtGrnData.taskGrnData[d]
+ geneInstabilities = zeros(totResponses, totLambdas)
+
+ for res in 1:totResponses
+ predInds = responsePredInds[d][res]
+ ssVals = ssMatrix[:, res, predInds, d]
+ theta2 = (1 / totSS) * ssVals
+ instabPerEdge = 2 * (theta2 .* (1 .- theta2))
+ aveInstab = vec(mean(instabPerEdge, dims=2))
+ maxUb = maximum(aveInstab)
+ maxUbInd = findlast(x -> x == maxUb, aveInstab)
+ aveInstab[maxUbInd:end] .= maxUb
+ geneInstabilities[res, :] = aveInstab
+ end
+
+ grnD.geneInstabilities = geneInstabilities
+ grnD.lambdaRange = lambdaRange
+ grnD.stabilityMat = ssMatrix[:, :, :, d]
+ grnD.betas = betas[:, :, :, d]
+
+ if outputDir !== nothing
+ taskDir = joinpath(outputDir, mtData.tasks[d])
+ mkpath(taskDir)
+ save_object(joinpath(taskDir, "instabOutMat.jld"), grnD)
+ end
+ end
+
+ println("Multitask instability estimation complete.")
+end
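The instability formula used above can be checked standalone: θ is the edge selection frequency over subsamples, and 2θ(1−θ) is maximal at θ = 0.5 (the edge flips in and out of the model) and zero when selection is deterministic.

```julia
# Standalone check of instab = 2θ(1 − θ) with θ = count / totSS.
totSS = 50
selectionCounts = [0, 10, 25, 40, 50]      # times an edge was selected
theta = selectionCounts ./ totSS
instab = 2 .* theta .* (1 .- theta)
# instab = [0.0, 0.32, 0.5, 0.32, 0.0]
```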
diff --git a/experimental/MTL/taskSimilarity.jl b/experimental/MTL/taskSimilarity.jl
new file mode 100755
index 0000000..9cd65e8
--- /dev/null
+++ b/experimental/MTL/taskSimilarity.jl
@@ -0,0 +1,138 @@
+"""
+taskSimilarity.jl [NEW FILE]
+
+Constructs the task similarity graph used as the fusion penalty structure.
+The graph determines which tasks are pulled toward each other by λ_f.
+
+Key design principle: similarity should come from an INDEPENDENT source,
+not from the same expression data used to fit the model (circular).
+Three options are provided in order of preference.
+"""
+
+
+"""
+    buildSimilarityFromMetadata(tasks, metadataFile)
+
+Build task similarity from a user-supplied metadata file.
+This is the preferred approach — avoids circular use of expression data.
+
+The metadata file should be a TSV with columns: Task, Group
+where Group defines which tasks are biologically related.
+Tasks in the same group get similarity = 1, across groups = 0.
+
+# Arguments
+- `tasks` : task names in order matching MultitaskExpressionData.tasks
+- `metadataFile` : path to TSV metadata file
+"""
+function buildSimilarityFromMetadata(tasks::Vector{String}, metadataFile::String)
+ df = CSV.read(metadataFile, DataFrame; delim='\t')
+ nTasks = length(tasks)
+ S = zeros(Float64, nTasks, nTasks)
+
+ taskToGroup = Dict(row.Task => row.Group for row in eachrow(df))
+
+ for i in 1:nTasks
+ for j in 1:nTasks
+ gi = get(taskToGroup, tasks[i], nothing)
+ gj = get(taskToGroup, tasks[j], nothing)
+ if gi !== nothing && gj !== nothing && gi == gj
+ S[i, j] = 1.0
+ end
+ end
+ end
+ # diagonal is always 1
+ for i in 1:nTasks
+ S[i, i] = 1.0
+ end
+ return S
+end
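The group logic reduces to a single comparison per task pair; here it is on a toy metadata table, using a plain Dict in place of the CSV-loaded DataFrame (task and group names are made up for illustration).

```julia
# Tasks sharing a Group get similarity 1; everything else stays 0.
tasks = ["Tfh", "Th1", "Treg"]
taskToGroup = Dict("Tfh" => "helper", "Th1" => "helper", "Treg" => "regulatory")

S = [Float64(taskToGroup[a] == taskToGroup[b]) for a in tasks, b in tasks]
# S = [1 1 0; 1 1 0; 0 0 1] — Tfh and Th1 are fused, Treg stands alone
```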
+
+
+"""
+ buildSimilarityFromExpression(mtData::MultitaskExpressionData; method=:pearson)
+
+Build task similarity from mean expression profiles per task.
+
+⚠️ WARNING: This is circular — you are using the same expression data
+to define task similarity AND to fit the model. Only use this as a
+last resort when no independent metadata is available. Consider
+using only a held-out subset of genes (e.g. housekeeping genes)
+for the similarity calculation.
+
+# Arguments
+- `mtData` : MultitaskExpressionData with taskData populated
+- `method` : :pearson (default) or :spearman
+"""
+function buildSimilarityFromExpression(mtData::MultitaskExpressionData; method::Symbol=:pearson)
+ nTasks = length(mtData.tasks)
+ # compute mean expression profile per task
+ meanProfiles = hcat([vec(mean(td.targGeneMat, dims=2)) for td in mtData.taskData]...) # genes × tasks
+ S = zeros(Float64, nTasks, nTasks)
+
+ for i in 1:nTasks
+ for j in 1:nTasks
+ if method == :pearson
+ S[i, j] = cor(meanProfiles[:, i], meanProfiles[:, j])
+ elseif method == :spearman
+ ri = tiedrank(meanProfiles[:, i])
+ rj = tiedrank(meanProfiles[:, j])
+ S[i, j] = cor(ri, rj)
+ end
+ end
+ end
+
+ # clip to [0, 1] — negative correlations treated as no similarity
+ S = max.(S, 0.0)
+ return S
+end
+
+
+"""
+ buildSimilarityFromOntology(tasks, ontologyFile)
+
+Build task similarity from a cell type ontology distance file.
+The ontology file should be a TSV with columns: Task1, Task2, Distance
+where Distance is a non-negative value (0 = identical, larger = more distant).
+
+Similarity is computed as sim = exp(-distance / scale) where scale
+is the median pairwise distance.
+"""
+function buildSimilarityFromOntology(tasks::Vector{String}, ontologyFile::String)
+ df = CSV.read(ontologyFile, DataFrame; delim='\t')
+ nTasks = length(tasks)
+    distMat = zeros(Float64, nTasks, nTasks)  # pairs absent from the file keep distance 0 (= max similarity)
+ taskIdx = Dict(t => i for (i, t) in enumerate(tasks))
+
+ for row in eachrow(df)
+ if haskey(taskIdx, row.Task1) && haskey(taskIdx, row.Task2)
+ i, j = taskIdx[row.Task1], taskIdx[row.Task2]
+ distMat[i, j] = row.Distance
+ distMat[j, i] = row.Distance
+ end
+ end
+
+ # convert distance to similarity via RBF kernel
+ offDiag = [distMat[i,j] for i in 1:nTasks for j in 1:nTasks if i != j]
+ scale = isempty(offDiag) ? 1.0 : median(offDiag)
+ S = exp.(-distMat ./ max(scale, 1e-8))
+ return S
+end
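The distance-to-similarity mapping behaves as follows on a toy distance matrix: with `scale` set to the median off-diagonal distance, a pair at exactly the median distance gets similarity exp(−1) ≈ 0.37, and the diagonal stays exactly 1.

```julia
using Statistics

# RBF conversion from the function above, on a hand-made 3-task distance matrix.
distMat = [0.0 1.0 4.0;
           1.0 0.0 3.0;
           4.0 3.0 0.0]
offDiag = [distMat[i, j] for i in 1:3 for j in 1:3 if i != j]
scale = median(offDiag)
S = exp.(-distMat ./ max(scale, 1e-8))   # S[i, i] == 1.0 for all i
```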
+
+
+"""
+ normalizeSimilarity!(S::Matrix{Float64})
+
+Row-normalize the similarity matrix so each row sums to 1.
+This keeps the total fusion pull on each task comparable regardless of
+how many neighbors it has; note that row normalization generally makes
+the matrix asymmetric.
+"""
+function normalizeSimilarity!(S::Matrix{Float64})
+ for i in 1:size(S, 1)
+ rowSum = sum(S[i, :])
+ if rowSum > 0
+ S[i, :] ./= rowSum
+ end
+ end
+ return S
+end
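Row normalization in isolation: after the loop each row sums to 1, and a previously symmetric matrix can lose its symmetry when row sums differ.

```julia
# Same loop as normalizeSimilarity!, inlined on a toy symmetric matrix.
S = [1.0 1.0 0.0;
     1.0 1.0 1.0;
     0.0 1.0 1.0]
for i in 1:size(S, 1)
    S[i, :] ./= sum(S[i, :])
end
sum(S, dims=2)        # every row ≈ 1.0
S[1, 2] == S[2, 1]    # false — row sums differed, so symmetry is lost
```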
diff --git a/experimental/MTL/~$MultitaskGRN_equations.pptx b/experimental/MTL/~$MultitaskGRN_equations.pptx
new file mode 100755
index 0000000..92f5f34
Binary files /dev/null and b/experimental/MTL/~$MultitaskGRN_equations.pptx differ
diff --git a/pipeline_diagram.png b/pipeline_diagram.png
new file mode 100644
index 0000000..fd9c761
Binary files /dev/null and b/pipeline_diagram.png differ
diff --git a/src/API.jl b/src/API.jl
new file mode 100755
index 0000000..66d4f9f
--- /dev/null
+++ b/src/API.jl
@@ -0,0 +1,418 @@
+# =============================================================================
+# InferelatorJL — Public API
+# src/API.jl
+#
+# High-level wrapper functions that compose the internal !-mutating calls
+# into clean, single-responsibility entry points.
+# All heavy lifting stays in the submodule files; this layer only
+# orchestrates and exposes a stable interface.
+# =============================================================================
+
+
+# -----------------------------------------------------------------------------
+# STEP 1 · Load & filter all expression data
+# -----------------------------------------------------------------------------
+"""
+ loadData(exprFile, targFile, regFile; tfaGeneFile="", epsilon=0.01)
+
+Load and filter all expression inputs into a `GeneExpressionData` struct.
+
+# Arguments
+- `exprFile` : Path to gene expression matrix (TSV or Arrow, genes × samples)
+- `targFile` : Path to target gene list (genes to model as responses)
+- `regFile` : Path to potential regulator list (candidate TFs)
+- `tfaGeneFile` : Optional path to gene list used for TFA estimation
+- `epsilon` : Minimum variance threshold for target gene filtering (default 0.01)
+
+# Returns
+`GeneExpressionData` with fields populated:
+`geneExpressionMat`, `targGeneMat`, `potRegMatmRNA`, `tfaGeneMat`
+"""
+function loadData(
+ exprFile::String,
+ targFile::String,
+ regFile::String;
+ tfaGeneFile::String = "",
+ epsilon::Float64 = 0.01
+)::GeneExpressionData
+
+ data = GeneExpressionData()
+ loadExpressionData!(data, exprFile)
+ loadAndFilterTargetGenes!(data, targFile; epsilon = epsilon)
+ loadPotentialRegulators!(data, regFile)
+ processTFAGenes!(data, tfaGeneFile)
+ return data
+end
+
+
+# -----------------------------------------------------------------------------
+# STEP 2 + 3 · Merge degenerate TFs then process prior & estimate TFA
+# -----------------------------------------------------------------------------
+"""
+ loadPrior(data, priorFile; minTargets=3, mergeDegenerate=true, connector="_")
+
+Merge degenerate TFs, process the prior matrix, and return both the
+`PriorTFAData` struct and the `mergedTFsResult` map.
+
+# Arguments
+- `data` : Populated `GeneExpressionData` from `loadData`
+- `priorFile` : Path to TF × Gene prior matrix (TSV)
+- `minTargets` : Minimum number of targets a TF must have to be retained (default 3)
+- `mergeDegenerate` : Whether to collapse TFs with identical target sets (default true)
+- `connector` : String used to join meta-TF names (default "_")
+
+# Returns
+Tuple `(PriorTFAData, mergedTFsResult)`
+"""
+function loadPrior(
+ data::GeneExpressionData,
+ priorFile::String;
+ minTargets::Int = 3,
+ mergeDegenerate::Bool = true,
+ connector::String = "_"
+)::Tuple{PriorTFAData, mergedTFsResult}
+
+ mergedTFs = mergedTFsResult()
+ if mergeDegenerate
+ mergeDegenerateTFs(mergedTFs, priorFile; fileFormat = 2, connector = connector)
+ end
+
+ priorData = PriorTFAData()
+ processPriorFile!(priorData, data, priorFile;
+ mergedTFsData = mergedTFs,
+ minTargets = minTargets)
+ return priorData, mergedTFs
+end
+
+
+# -----------------------------------------------------------------------------
+# STEP 3 (cont.) · Estimate TF activity
+# -----------------------------------------------------------------------------
+"""
+ estimateTFA(priorData, data; edgeSS=0, zScoreTFA=true, outputDir=".")
+
+Estimate TF activity (TFA) by solving `prior * TFA ≈ targetExpression`
+via least squares, with optional bootstrap subsampling of targets.
+
+# Arguments
+- `priorData` : `PriorTFAData` from `loadPrior`
+- `data` : `GeneExpressionData` from `loadData`
+- `edgeSS` : Number of edge subsamples (0 = no subsampling)
+- `zScoreTFA` : Z-score target expression before solving TFA (default true)
+- `outputDir` : Directory for intermediate output files (default ".")
+
+# Returns
+TFA matrix as `Matrix{Float64}` (TFs × samples), also stored in `priorData.medTfas`
+"""
+function estimateTFA(
+ priorData::PriorTFAData,
+ data::GeneExpressionData;
+ edgeSS::Int = 0,
+ zScoreTFA::Bool = true,
+ outputDir::String = "."
+)::Matrix{Float64}
+
+ calculateTFA!(priorData, data;
+ edgeSS = edgeSS,
+ zTarget = zScoreTFA,
+ outputDir = outputDir)
+ return priorData.medTfas
+end
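The core of the TFA model stated in the docstring fits in a few lines: with prior `P` (genes × TFs) and target expression `X` (genes × samples), solving `P * A ≈ X` for activities `A` is a least-squares problem, which Julia writes as a left division. `calculateTFA!` layers z-scoring and optional subsampling on top of this; the matrices below are made up for illustration.

```julia
# Miniature TFA estimation via least squares.
P = [1.0 0.0;
     1.0 1.0;
     0.0 1.0]                # 3 genes × 2 TFs prior
A_true = [1.0 2.0;
          3.0 1.0]           # 2 TFs × 2 samples
X = P * A_true               # synthetic target expression
A_hat = P \ X                # recovers A_true up to numerical error
```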
+
+
+# -----------------------------------------------------------------------------
+# STEP 4 · Build a single GRN (one predictor mode)
+# -----------------------------------------------------------------------------
+"""
+ buildNetwork(data, priorData; kwargs...) → BuildGrn
+
+Run the full mLASSO-StARS pipeline for one predictor mode and return the
+ranked edge table.
+
+Internally runs:
+`preparePredictorMat!` → `preparePenaltyMatrix!` → `constructSubsamples` →
+`bstarsWarmStart` → `constructSubsamples` (full) → `bstartsEstimateInstability` →
+`chooseLambda!` → `rankEdges!` → `writeNetworkTable!`
+
+# Arguments
+- `data` : `GeneExpressionData`
+- `priorData` : `PriorTFAData`
+- `tfaMode` : true = TFA predictors, false = raw mRNA for all TFs
+- `priorFilePenalties` : Prior file(s) used to build the penalty matrix
+- `lambdaBias` : Penalty reduction factor for prior edges (default [0.5])
+- `totSS` : Total subsamples for fine instability estimation (default 80)
+- `bstarsTotSS` : Subsamples for warm-start (default 5)
+- `subsampleFrac` : Fraction of samples per subsample (default 0.63)
+- `minLambda` : Lower bound of λ search range (default 0.01)
+- `maxLambda` : Upper bound of λ search range (default 0.5)
+- `totLambdasBstars` : λ grid points in warm-start pass (default 20)
+- `totLambdas` : λ grid points in fine estimation pass (default 40)
+- `targetInstability` : Instability threshold for λ selection (default 0.05)
+- `meanEdgesPerGene` : Max edges retained per target gene (default 20)
+- `correlationWeight` : Weight of partial correlation in edge scoring (default 1)
+- `instabilityLevel` : "Network" (single λ) or "Gene" (per-gene λ)
+- `useMeanEdgesPerGeneMode`: Enforce per-gene edge cap (default true)
+- `zScoreLASSO`            : Z-score targets during LASSO regression (default true)
+- `outputDir` : Output directory for edges.tsv and stability arrays
+
+# Returns
+`BuildGrn` with `networkStability`, `signedQuantile`, `networkMat`, etc.
+"""
+function buildNetwork(
+ data::GeneExpressionData,
+ priorData::PriorTFAData;
+ tfaMode::Bool = true,
+ priorFilePenalties::Vector{String} = String[],
+ lambdaBias::Vector{Float64} = [0.5],
+ totSS::Int = 80,
+ bstarsTotSS::Int = 5,
+ subsampleFrac::Float64 = 0.63,
+ minLambda::Float64 = 0.01,
+ maxLambda::Float64 = 0.5,
+ totLambdasBstars::Int = 20,
+ totLambdas::Int = 40,
+ targetInstability::Float64 = 0.05,
+ meanEdgesPerGene::Int = 20,
+ correlationWeight::Int = 1,
+ instabilityLevel::String = "Network",
+ useMeanEdgesPerGeneMode::Bool = true,
+ zScoreLASSO::Bool = true,
+ outputDir::String = "."
+)::BuildGrn
+
+ tfaOpt = tfaMode ? "" : "TFmRNA"
+
+ grnData = GrnData()
+ preparePredictorMat!(grnData, data, priorData; tfaOpt = tfaOpt)
+ preparePenaltyMatrix!(data, grnData;
+ priorFilePenalties = priorFilePenalties,
+ lambdaBias = lambdaBias,
+ tfaOpt = tfaOpt)
+
+ # Warm-start pass (coarse λ range)
+ constructSubsamples(data, grnData; totSS = bstarsTotSS, subsampleFrac = subsampleFrac)
+ bstarsWarmStart(data, priorData, grnData;
+ minLambda = minLambda,
+ maxLambda = maxLambda,
+ totLambdasBstars = totLambdasBstars,
+ targetInstability = targetInstability,
+ zTarget = zScoreLASSO)
+
+ # Fine estimation pass
+ constructSubsamples(data, grnData; totSS = totSS, subsampleFrac = subsampleFrac)
+ bstartsEstimateInstability(grnData;
+ totLambdas = totLambdas,
+ instabilityLevel = instabilityLevel,
+ zTarget = zScoreLASSO,
+ outputDir = outputDir)
+
+ buildGrn = BuildGrn()
+ chooseLambda!(grnData, buildGrn;
+ instabilityLevel = instabilityLevel,
+ targetInstability = targetInstability)
+ rankEdges!(data, priorData, grnData, buildGrn;
+ useMeanEdgesPerGeneMode = useMeanEdgesPerGeneMode,
+ meanEdgesPerGene = meanEdgesPerGene,
+ correlationWeight = correlationWeight,
+ outputDir = outputDir)
+ writeNetworkTable!(buildGrn; outputDir = outputDir)
+
+ return buildGrn
+end
+
+
+# -----------------------------------------------------------------------------
+# STEP 6 · Recalculate TFA using the combined network as a refined prior
+# -----------------------------------------------------------------------------
+"""
+ refineTFA(combinedNetFile, data, mergedTFs; kwargs...) → Matrix{Float64}
+
+Use the aggregated consensus network as a new prior and re-estimate TF
+activity, yielding a data-driven TFA matrix that reflects the final GRN.
+
+# Arguments
+- `combinedNetFile` : Path to sparse combined network TSV (combined_max_sp.tsv)
+- `data` : `GeneExpressionData`
+- `mergedTFs` : `mergedTFsResult` from `loadPrior`
+- `tfaGeneFile` : Optional gene list for TFA (default "")
+- `edgeSS` : Edge subsampling replicates for TFA (default 0)
+- `minTargets` : Minimum targets per TF (default 3)
+- `zScoreTFA` : Z-score target expression before solving TFA (default true)
+- `exprFile` : Original expression file path
+- `targFile` : Target gene file path
+- `regFile` : Regulator file path
+- `outputDir` : Directory for refined TFA output (default ".")
+
+# Returns
+Refined TFA matrix as `Matrix{Float64}` (regulators × samples)
+"""
+function refineTFA(
+ combinedNetFile::String,
+ data::GeneExpressionData,
+ mergedTFs::mergedTFsResult;
+ tfaGeneFile::String = "",
+ edgeSS::Int = 0,
+ minTargets::Int = 3,
+ zScoreTFA::Bool = true,
+ exprFile::String = "",
+ targFile::String = "",
+ regFile::String = "",
+ outputDir::String = "."
+)
+
+ # Dispatch to internal refineTFA(data::GeneExpressionData, ...) — different first-arg type
+ refineTFA(data, mergedTFs;
+ priorFile = combinedNetFile,
+ tfaGeneFile = tfaGeneFile,
+ edgeSS = edgeSS,
+ minTargets = minTargets,
+ zTarget = zScoreTFA,
+ geneExprFile = exprFile,
+ targFile = targFile,
+ regFile = regFile,
+ outputDir = outputDir)
+end
+
+
+# -----------------------------------------------------------------------------
+# STEP 7 · Evaluate a network against a gold standard
+# -----------------------------------------------------------------------------
+"""
+ evaluateNetwork(networkFile, goldStandard; metric=:AUPR) → NamedTuple
+
+Evaluate a ranked edge list against a gold-standard network.
+
+# Arguments
+- `networkFile` : Path to edges TSV (TF, Gene, signedQuantile, ...)
+- `goldStandard` : Path to gold-standard network TSV
+- `metric` : Evaluation metric — `:AUPR`, `:AUROC`, or `:both` (default `:AUPR`)
+
+# Returns
+`NamedTuple` with fields depending on `metric`:
+`(aupr=..., auroc=..., precision=..., recall=...)`
+"""
+function evaluateNetwork(
+ networkFile::String,
+ goldStandard::String;
+ metric::Symbol = :AUPR
+)
+
+ # Delegates to Metrics submodule
+ computeMacroMetrics(networkFile, goldStandard; metric = metric)
+end
+
+
+# -----------------------------------------------------------------------------
+# CONVENIENCE · Full pipeline in one call
+# -----------------------------------------------------------------------------
+"""
+ inferGRN(exprFile, targFile, regFile, priorFile; outputDir="results", kwargs...)
+
+Run the complete InferelatorJL pipeline end-to-end:
+
+1. Load & filter expression data
+2. Merge degenerate TFs + process prior
+3. Estimate TFA
+4. Build TFA-mode network
+5. Build mRNA-mode network
+6. Aggregate both networks
+7. Recalculate TFA on combined network
+
+All keyword arguments are forwarded to the relevant sub-functions.
+
+# Returns
+`BuildGrn` from the TFA-mode network (combined results written to `outputDir`)
+"""
+function inferGRN(
+ exprFile::String,
+ targFile::String,
+ regFile::String,
+ priorFile::String;
+ outputDir::String = "results",
+ tfaGeneFile::String = "",
+ epsilon::Float64 = 0.01,
+ minTargets::Int = 3,
+ edgeSS::Int = 0,
+ zScoreTFA::Bool = true,
+ zScoreLASSO::Bool = true,
+ priorFilePenalties::Vector{String} = String[],
+ lambdaBias::Vector{Float64} = [0.5],
+ totSS::Int = 80,
+ bstarsTotSS::Int = 5,
+ subsampleFrac::Float64 = 0.63,
+ minLambda::Float64 = 0.01,
+ maxLambda::Float64 = 0.5,
+ totLambdasBstars::Int = 20,
+ totLambdas::Int = 40,
+ targetInstability::Float64 = 0.05,
+ meanEdgesPerGene::Int = 20,
+ correlationWeight::Int = 1,
+ instabilityLevel::String = "Network",
+ useMeanEdgesPerGeneMode::Bool = true,
+ combineMethod::Symbol = :max
+)::BuildGrn
+
+ mkpath(outputDir)
+ tfaDir = joinpath(outputDir, "TFA")
+ mRNADir = joinpath(outputDir, "TFmRNA")
+ combDir = joinpath(outputDir, "Combined")
+ mkpath(tfaDir); mkpath(mRNADir); mkpath(combDir)
+
+ # Shared kwargs for buildNetwork
+ netKwargs = (
+ priorFilePenalties = isempty(priorFilePenalties) ? [priorFile] : priorFilePenalties,
+ lambdaBias = lambdaBias,
+ totSS = totSS,
+ bstarsTotSS = bstarsTotSS,
+ subsampleFrac = subsampleFrac,
+ minLambda = minLambda,
+ maxLambda = maxLambda,
+ totLambdasBstars = totLambdasBstars,
+ totLambdas = totLambdas,
+ targetInstability = targetInstability,
+ meanEdgesPerGene = meanEdgesPerGene,
+ correlationWeight = correlationWeight,
+ instabilityLevel = instabilityLevel,
+ useMeanEdgesPerGeneMode = useMeanEdgesPerGeneMode,
+ zScoreLASSO = zScoreLASSO,
+ )
+
+ # Steps 1–3
+ data = loadData(exprFile, targFile, regFile;
+ tfaGeneFile = tfaGeneFile, epsilon = epsilon)
+ priorData, mergedTFs = loadPrior(data, priorFile; minTargets = minTargets)
+ estimateTFA(priorData, data; edgeSS = edgeSS, zScoreTFA = zScoreTFA,
+ outputDir = outputDir)
+
+ # Step 4 — TFA mode
+ tfaGrn = buildNetwork(data, priorData; tfaMode = true,
+ netKwargs..., outputDir = tfaDir)
+
+    # Step 5 — mRNA mode
+ mrnaGrn = buildNetwork(data, priorData; tfaMode = false,
+ netKwargs..., outputDir = mRNADir)
+
+    # Step 6 — aggregate
+ aggregateNetworks(
+ [joinpath(tfaDir, "edges.tsv"),
+ joinpath(mRNADir, "edges.tsv")];
+ method = combineMethod,
+ meanEdgesPerGene = meanEdgesPerGene,
+ useMeanEdgesPerGene = useMeanEdgesPerGeneMode,
+ outputDir = combDir
+ )
+
+    # Step 7 — refine TFA
+ combinedSparse = joinpath(combDir, "combined_" * string(combineMethod) * "_sp.tsv")
+ refineTFA(combinedSparse, data, mergedTFs;
+ tfaGeneFile = tfaGeneFile,
+ edgeSS = edgeSS,
+ minTargets = minTargets,
+ zScoreTFA = zScoreTFA, # API wrapper translates → zTarget internally
+ exprFile = exprFile,
+ targFile = targFile,
+ regFile = regFile,
+ outputDir = combDir)
+
+ return tfaGrn
+end
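A typical end-to-end call looks like this. The file paths are placeholders, not files that ship with the package — substitute your own inputs; keyword defaults match the signature above.

```julia
using InferelatorJL

# One-call pipeline: load → prior/TFA → both network modes → aggregate → refine.
grn = inferGRN(
    "data/expression.tsv",      # genes × samples expression matrix
    "data/target_genes.tsv",    # target genes, one per line
    "data/regulators.tsv",      # candidate TFs, one per line
    "data/prior.tsv";           # TF × gene prior matrix
    outputDir = "results",
    totSS = 80,
    targetInstability = 0.05,
)
# results/TFA/, results/TFmRNA/, and results/Combined/ now hold the networks.
```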
diff --git a/src/InferelatorJL.jl b/src/InferelatorJL.jl
new file mode 100755
index 0000000..7f24e38
--- /dev/null
+++ b/src/InferelatorJL.jl
@@ -0,0 +1,84 @@
+module InferelatorJL
+
+# ── Dependencies ──────────────────────────────────────────────────────────────
+using ArgParse, Arrow, CSV, CategoricalArrays, Colors
+using DataFrames, Distributions, FileIO, GLMNet
+using InlineStrings, Interpolations, JLD2, Measures
+using NamedArrays, OrderedCollections, ProgressBars
+using PyPlot, Random, SparseArrays, StatsBase, TickTock
+
+# ── Structs (always first — all type definitions in one place) ────────────────
+include("Types.jl") # GeneExpressionData, mergedTFsResult, PriorTFAData, GrnData, BuildGrn
+
+# ── Functions ─────────────────────────────────────────────────────────────────
+include("data/GeneExpressionData.jl") # data loading functions
+include("prior/MergeDegenerateTFs.jl") # TF merging functions
+include("data/PriorTFAData.jl") # prior processing + TFA functions
+
+# ── Core pipeline ─────────────────────────────────────────────────────────────
+include("utils/DataUtils.jl") # Data transformation utilities
+include("utils/NetworkIO.jl") # Save/write network files
+include("utils/PartialCorrelation.jl") # Partial correlation via precision matrix
+
+include("grn/PrepareGRN.jl") # preparePredictorMat!, preparePenaltyMatrix!, constructSubsamples
+include("grn/BuildGRN.jl") # bstarsWarmStart, bstartsEstimateInstability, chooseLambda!, rankEdges!
+include("grn/AggregateNetworks.jl") # combineGRNs / aggregateNetworks
+include("grn/RefineTFA.jl") # combineGRNS2 / refineTFA
+include("grn/UtilsGRN.jl") # GRN utility helpers
+
+# ── Metrics ───────────────────────────────────────────────────────────────────
+include("metrics/Constants.jl")
+include("metrics/MetricUtils.jl")
+include("metrics/CalcPR.jl")
+include("metrics/Metrics.jl")
+include("metrics/plotting/PlotSingle.jl")
+include("metrics/plotting/PlotBatch.jl")
+
+# ── Public API ────────────────────────────────────────────────────────────────
+include("API.jl")
+
+# ── Exports ───────────────────────────────────────────────────────────────────
+export
+ # Structs
+ GeneExpressionData,
+ mergedTFsResult,
+ PriorTFAData,
+ GrnData,
+ BuildGrn,
+
+ # High-level API (src/API.jl)
+ loadData,
+ loadPrior,
+ estimateTFA,
+ buildNetwork,
+ aggregateNetworks,
+ refineTFA,
+ evaluateNetwork,
+ inferGRN,
+
+ # Data utilities
+ convertToLong,
+ convertToWide,
+ frobeniusNormalize,
+ completeDF,
+ mergeDFs,
+ check_column_norms,
+ writeTSVWithEmptyFirstHeader,
+ binarizeNumeric!,
+
+ # I/O
+ saveData,
+ writeNetworkTable!
+
+ # -------------------------------------------------------------------------
+ # Internal pipeline functions are intentionally NOT exported.
+    # They remain accessible via qualified names (e.g. InferelatorJL.calculateTFA!) if needed.
+ # loadExpressionData! loadAndFilterTargetGenes! loadPotentialRegulators!
+ # processTFAGenes! processPriorFile! mergeDegenerateTFs
+ # calculateTFA! preparePredictorMat! preparePenaltyMatrix!
+ # constructSubsamples bstarsWarmStart bstartsEstimateInstability
+ # chooseLambda! rankEdges!
+ # -------------------------------------------------------------------------
+
+end # module InferelatorJL
diff --git a/src/Types.jl b/src/Types.jl
new file mode 100755
index 0000000..5286d92
--- /dev/null
+++ b/src/Types.jl
@@ -0,0 +1,166 @@
+# =============================================================================
+# Types.jl — All struct definitions for InferelatorJL
+#
+# Included FIRST in InferelatorJL.jl so every function file can freely
+# reference any type without worrying about include order.
+# =============================================================================
+
+
+# ── Data loading ──────────────────────────────────────────────────────────────
+
+mutable struct GeneExpressionData
+ cellLabels::Vector{String}
+ geneNames::Vector{String}
+ geneExpressionMat::Matrix{Float64}
+ potRegMatmRNA::Matrix{Float64}
+ potRegs::Vector{String}
+ potRegsmRNA::Vector{String}
+ targGenes::Vector{String}
+ targGeneMat::Matrix{Float64}
+ tfaGenes::Vector{String}
+ tfaGeneMat::Matrix{Float64}
+
+ function GeneExpressionData()
+ return new(
+ [],
+ [],
+ Matrix{Float64}(undef, 0, 0),
+ Matrix{Float64}(undef, 0, 0),
+ [],
+ [],
+ [],
+ Matrix{Float64}(undef, 0, 0),
+ [],
+ Matrix{Float64}(undef, 0, 0)
+ )
+ end
+end
+
+
+# ── Prior / TF merging ────────────────────────────────────────────────────────
+
+mutable struct mergedTFsResult
+ mergedPrior::Union{DataFrame, Nothing}
+ mergedTFMap::Union{Matrix{String}, Nothing}
+
+ function mergedTFsResult()
+ return new(nothing, nothing)
+ end
+end
+
+
+# ── TFA ───────────────────────────────────────────────────────────────────────
+
+mutable struct PriorTFAData
+ pRegs::Vector{String}
+ pTargs::Vector{String}
+ priorMatrix::Matrix{Float64}
+ pRegsNoTfa::Vector{String}
+ pTargsNoTfa::Vector{String}
+ priorMatrixNoTfa::Matrix{Float64}
+ noPriorRegs::Vector{String}
+ noPriorRegsMat::Matrix{Float64}
+ targExpression::Matrix{Float64}
+ medTfas::Matrix{Float64}
+
+ function PriorTFAData()
+ return new(
+ [],
+ [],
+ Matrix{Float64}(undef, 0, 0),
+ [],
+ [],
+ Matrix{Float64}(undef, 0, 0),
+ [],
+ Matrix{Float64}(undef, 0, 0),
+ Matrix{Float64}(undef, 0, 0),
+ Matrix{Float64}(undef, 0, 0)
+ )
+ end
+end
+
+
+# ── GRN ───────────────────────────────────────────────────────────────────────
+
+mutable struct GrnData
+ predictorMat::Matrix{Float64}
+ penaltyMat::Matrix{Float64}
+ allPredictors::Vector{String}
+ subsamps::Matrix{Int64}
+ responseMat::Matrix{Float64}
+ maxLambdaNet::Float64
+ minLambdaNet::Float64
+ minLambdas::Matrix{Float64}
+ maxLambdas::Matrix{Float64}
+ netInstabilitiesUb::Vector{Float64}
+ netInstabilitiesLb::Vector{Float64}
+ instabilitiesUb::Matrix{Float64}
+ instabilitiesLb::Matrix{Float64}
+ netInstabilities::Vector{Float64}
+ geneInstabilities::Matrix{Float64}
+ lambdaRange::Vector{Float64}
+ lambdaRangeGene::Vector{Vector{Float64}}
+ stabilityMat::Array{Float64}
+ priorMatProcessed::Matrix{Float64}
+ betas::Array{Float64,3}
+
+ function GrnData()
+ return new(
+ Matrix{Float64}(undef, 0, 0), # predictorMat
+ Matrix{Float64}(undef, 0, 0), # penaltyMat
+ [], # allPredictors
+            Matrix{Int64}(undef, 0, 0),        # subsamps
+ Matrix{Float64}(undef, 0, 0), # responseMat
+ 0.0, # maxLambdaNet
+ 0.0, # minLambdaNet
+ Matrix{Float64}(undef, 0, 0), # minLambdas
+ Matrix{Float64}(undef, 0, 0), # maxLambdas
+ [], # netInstabilitiesUb
+ [], # netInstabilitiesLb
+ Matrix{Float64}(undef, 0, 0), # instabilitiesUb
+ Matrix{Float64}(undef, 0, 0), # instabilitiesLb
+ [], # netInstabilities
+ Matrix{Float64}(undef, 0, 0), # geneInstabilities
+ [], # lambdaRange
+ Vector{Vector{Float64}}(undef, 0), # lambdaRangeGene
+ Array{Float64}(undef, 0, 0, 0), # stabilityMat (3-D)
+ Matrix{Float64}(undef, 0, 0), # priorMatProcessed
+ Array{Float64,3}(undef, 0, 0, 0) # betas
+ )
+ end
+end
+
+
+mutable struct BuildGrn
+ networkStability::Matrix{Float64}
+ lambda::Union{Float64, Vector{Float64}}
+ targs::Vector{String}
+ regs::Vector{String}
+ rankings::Vector{Float64}
+ signedQuantile::Vector{Float64}
+ partialCorrelation::Vector{Float64}
+ inPrior::Vector{String}
+ networkMat::Matrix{Any}
+ networkMatSubset::Matrix{Any}
+ inPriorVec::Vector{Float64}
+ betas::Matrix{Float64}
+ mergeTfLocVec::Vector{Float64}
+
+ function BuildGrn()
+ return new(
+ Matrix{Float64}(undef, 0, 0), # networkStability
+ 0.0, # lambda
+ [], # targs
+ [], # regs
+ [], # rankings
+ [], # signedQuantile
+ [], # partialCorrelation
+ [], # inPrior
+ Matrix{Float64}(undef, 0, 0), # networkMat
+ Matrix{Float64}(undef, 0, 0), # networkMatSubset
+ [], # inPriorVec
+ Matrix{Float64}(undef, 0, 0), # betas
+ Float64[] # mergeTfLocVec
+ )
+ end
+end
diff --git a/src/Utils/dataUtils.jl b/src/Utils/dataUtils.jl
new file mode 100755
index 0000000..5ad07d5
--- /dev/null
+++ b/src/Utils/dataUtils.jl
@@ -0,0 +1,326 @@
+using DataFrames
+using CSV
+using LinearAlgebra
+using TickTock
+using FileIO
+
+# Convert wide-format data (one column per TF) to long format.
+# DataFrames with 3 or fewer columns are assumed to already be long and are
+# returned unchanged.
+# USAGE: dfs = [convertToLong(df) for df in dfs]
+function convertToLong(data)
+    if ncol(data) > 3
+        return stack(data, Not(1))
+    else
+        return data
+    end
+end
+
+"""
+Converts a 3-column long-format DataFrame to wide format using `unstack`.
+
+- `Data`: either a wide matrix with TFs as columns and target genes as rows, or
+  long data with columns in the order TF, Gene, Weight.
+- `indices::Union{Nothing, NTuple{3, Int}} = nothing`:
+  A tuple specifying the column indices in the order (pivot, key, value).
+  - pivot: The column that provides the row identifier.
+  - key: The column whose unique values will become new column names.
+  - value: The column from which the cell values are taken.
+
+If no indices are provided, the function defaults to (2, 1, 3), which matches
+long input ordered (TF, Gene, Weight).
+If the DataFrame has more than 3 columns, the original DataFrame is returned.
+If the DataFrame has fewer than 3 columns, an error is thrown.
+
+# USAGE
+dfs = [convertToWide(df) for df in dfs]
+dfs = [convertToWide(df; indices = (1,2,3)) for df in dfs]
+"""
+function convertToWide(Data; indices::Union{Nothing, NTuple{3, Int}}=nothing)
+ ncols = ncol(Data)
+
+ if ncols < 3
+        error("DataFrame has fewer than 3 columns. A 3-column DataFrame or a wide matrix is required.")
+ elseif ncols > 3
+ # More than 3 columns: return data unchanged.
+ return Data
+ else # Exactly 3 columns
+ # Use provided indices, or default to (2, 1, 3)
+ inds = isnothing(indices) ? (2,1,3) : indices
+ # Convert the specified columns to Symbols for unstack.
+ idSym = Symbol(names(Data)[inds[1]])
+ keySym = Symbol(names(Data)[inds[2]])
+ valueSym = Symbol(names(Data)[inds[3]])
+
+ return unstack(Data, idSym, keySym, valueSym)
+ end
+end
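+
+#=
+# Quick sketch on toy (hypothetical) data: wide -> long -> wide round trip.
+# `stack` emits columns in the order (Gene, variable, value), so pass
+# indices = (1,2,3); the default (2,1,3) assumes long input ordered (TF, Gene, Weight).
+df   = DataFrame(Gene = ["g1","g2"], TF1 = [1.0,0.0], TF2 = [0.5,2.0], TF3 = [0.0,1.0])
+long = convertToLong(df)                      # 4 cols -> stacked to (Gene, variable, value)
+wide = convertToWide(long; indices = (1,2,3)) # one column per TF again
+=#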
+
+
+
+
+"""
+frobeniusNormalize(df::DataFrame, dims::Symbol = :column)
+
+Normalize a DataFrame by the Frobenius (L2) norm.
+
+• If dims = :column (default), it normalizes each column (ignoring the first).
+• If dims = :row, it normalizes each row (ignoring the first column).
+
+The first column is assumed to contain non-numeric data (e.g., row identifiers)
+and is left unchanged.
+"""
+function frobeniusNormalize(df::DataFrame, dims::Symbol = :column)
+ dfNorm = deepcopy(df)
+ dfMat = Matrix(dfNorm[!, 2:end]) # Extract numerical part
+ dfMat = convert(Matrix{Float64}, dfMat) # Ensure it's Float64
+
+ if dims == :row
+ # Normalize each row.
+ for i in 1:size(dfMat, 1)
+ nrm = norm(dfMat[i, :], 2) # Compute the L2 norm for the row.
+ if nrm != 0
+ dfMat[i, :] ./= nrm
+ end
+ end
+ elseif dims == :column
+ # Normalize each column.
+ for j in 1:size(dfMat, 2)
+ nrm = norm(dfMat[:, j], 2) # Compute the L2 norm for the column.
+ if nrm != 0
+ dfMat[:, j] ./= nrm
+ end
+ end
+ else
+ throw(ArgumentError("dims must be :row or :column"))
+ end
+
+ # Update the DataFrame with the normalized numeric values.
+ dfNorm[!, 2:end] .= dfMat
+ return dfNorm
+end
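+
+#=
+# Sketch on toy data: after column-wise normalization, every numeric column has
+# unit L2 norm (3-4-5 triangle: [3,4] becomes [0.6,0.8]).
+df     = DataFrame(Gene = ["g1","g2"], A = [3.0,4.0], B = [1.0,1.0])
+dfNorm = frobeniusNormalize(df, :column)
+norm(dfNorm.A, 2) # ≈ 1.0
+=#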
+
+
+"""
+completeDF(df::DataFrame, id::Symbol, idsALL, allCols)
+
+Aligns a DataFrame to a common set of row identifiers and columns.
+
+### Arguments:
+- `df::DataFrame`: The input DataFrame to align.
+- `id::Symbol`: The identifier column (e.g., gene or sample ID).
+- `idsALL`: A vector of all unique row identifiers across multiple DataFrames.
+- `allCols`: A vector of all unique column names across multiple DataFrames.
+
+### Returns:
+- A new DataFrame with:
+- Rows corresponding to `idsALL` (missing rows filled with `missing`).
+- Columns matching `allCols` (missing columns filled with `missing`).
+- The original values retained where available.
+
+"""
+function completeDF(df::DataFrame, id::Symbol, idsALL, allCols)
+ dfNew = DataFrame()
+ dfNew[!, id] = idsALL # Ensure all IDs are included
+
+ for col in allCols
+ if col in Symbol.(names(df))
+ mapping = Dict(row[id] => row[col] for row in eachrow(df)) # Store existing values
+ dfNew[!, col] = [get(mapping, rid, missing) for rid in idsALL] # Align data
+ else
+ dfNew[!, col] = fill(missing, length(idsALL)) # Fill missing columns
+ end
+ end
+ return dfNew
+end
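+
+#=
+# Sketch on toy data: align onto the union of IDs/columns; gaps become `missing`.
+df  = DataFrame(Gene = ["g1","g2"], A = [1.0,2.0])
+out = completeDF(df, :Gene, ["g1","g2","g3"], [:A, :B])
+# out has rows g1..g3; out.A == [1.0, 2.0, missing]; out.B is all missing.
+=#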
+
+"""
+Merges a vector of DataFrames by summing or averaging them cell-wise.
+
+Arguments:
+- dfs::Vector{DataFrame}: A vector of DataFrames to merge.
+- id::Symbol: The identifier column (common to every DataFrame).
+- option::String: The operation to perform: either "sum" or "avg".
+
+Returns:
+A DataFrame whose first column is the identifier and whose remaining columns are
+the cell-wise sum or average (depending on `option`) of the numeric values from
+all DataFrames, with columns sorted alphabetically.
+"""
+function mergeDFs(dfs::Vector{DataFrame}, id::Symbol = :Gene, option::String = "sum")
+
+ # 1. Compute the full set of row identifiers
+ idsALL = reduce(union, [unique(df[!, id]) for df in dfs])
+ sort!(idsALL)
+
+ # 2. Compute the full set of numeric columns
+ allCols = reduce(union, [setdiff(names(df), [string(id)]) for df in dfs])
+ allCols = sort(Symbol.(allCols))
+
+ # 3. Complete all DataFrames to have the same structure
+ completed_dfs = [completeDF(df, id, idsALL, allCols) for df in dfs]
+
+ # 4. Initialize matrices for sum and count tracking
+ nRows, nCols = length(idsALL), length(allCols)
+ sumMat = zeros(nRows, nCols)
+ countMat = zeros(Int, nRows, nCols)
+
+ # 5. Compute sum and count matrices
+ for df in completed_dfs
+ mat = Matrix(df[!, allCols])
+ mat = coalesce.(mat, 0.0)
+ sumMat .+= mat
+ countMat .+= (mat .!= 0) # Count nonzero interactions
+ end
+
+ # 6. Compute sum or average
+ resultMat = option == "avg" ? sumMat ./ max.(countMat, 1) : sumMat
+
+ # 7. Create the merged DataFrame
+ merged = DataFrame(id => idsALL)
+ for (j, col) in enumerate(allCols)
+ merged[!, col] = resultMat[:, j]
+ end
+
+ return merged
+end
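+
+#=
+# Sketch on toy data: "avg" divides each cell by its *nonzero* count (countMat),
+# not by the number of DataFrames, so absent entries don't dilute the mean.
+dfA = DataFrame(Gene = ["g1"], TF1 = [2.0])
+dfB = DataFrame(Gene = ["g1"], TF1 = [4.0], TF2 = [6.0])
+mergeDFs([dfA, dfB], :Gene, "avg") # TF1 -> 3.0, TF2 -> 6.0
+=#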
+
+"""
+Checks whether each numeric column (except the first) is L2-normalized
+(norm ≈ 1). Returns a Dict of column name => Bool, plus the counts of
+normalized (true) and non-normalized (false) columns.
+"""
+function check_column_norms(df::DataFrame; atol=1e-8)
+ details = Dict{String,Bool}()
+ count_true = 0
+ count_false = 0
+
+ for col in names(df)[2:end]
+ # Check if the column is numeric by testing its element type.
+ if !(eltype(df[!, col]) <: Number)
+ println("Column: ", col, " is non-numeric, skipping...")
+ continue
+ end
+ # Convert the column values to Float64.
+ values = Float64.(df[!, col])
+ # Compute the L₂ norm.
+ col_norm = norm(values, 2)
+ # Check whether the norm is approximately 1.
+ is_norm_one = isapprox(col_norm, 1.0; atol=atol)
+
+ details[string(col)] = is_norm_one
+ if is_norm_one
+ count_true += 1
+ else
+ count_false += 1
+ end
+ # println("Column: ", col, " Norm: ", col_norm, " ~ 1? ", is_norm_one)
+ end
+
+ println("Total numeric columns normalized (True): ", count_true)
+ println("Total numeric columns not normalized (False): ", count_false)
+ return details, count_true, count_false
+end
+
+function writeTSVWithEmptyFirstHeader(df::DataFrame, filepath::String; delim::Char='\t')
+ open(filepath, "w") do io
+ # Create custom header:
+ # Make the first header cell an empty string.
+ # The remaining header cells are the remaining column names.
+ hdr = [""; names(df)[2:end]]
+ # Write the header row using the delimiter.
+ println(io, join(hdr, delim))
+
+ # Write each row. Each row is converted into strings.
+ for row in eachrow(df)
+ # Convert every element of the row to a string.
+ row_values = [string(x) for x in collect(row)]
+ println(io, join(row_values, delim))
+ end
+ end
+end
+
+function binarizeNumeric!(data)
+ if isa(data, DataFrame)
+ for col in names(data)
+ if eltype(data[!, col]) <: Number
+ data[!, col] .= Int.(data[!, col] .!= 0)
+ end
+ end
+ return data
+ elseif isa(data, AbstractMatrix)
+ data .= Int.(data .!= 0)
+ return data
+ else
+ throw(ArgumentError("Input must be DataFrame or matrix"))
+ end
+end
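+
+#=
+# Sketch: any nonzero entry becomes 1, in place. For a Float64 matrix the result
+# is stored as 0.0/1.0 since the eltype is unchanged.
+M = [0.0 2.5; -1.0 0.0]
+binarizeNumeric!(M) # -> [0.0 1.0; 1.0 0.0]
+=#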
+
+
+# function binarizeNonZero!(data)
+# if isa(data, DataFrame)
+# for c in names(data)
+# if eltype(data[!, c]) <: Number
+# data[!, c] .= ifelse.(data[!, c] .!= 0, 1, 0)
+# end
+# end
+# elseif isa(data, AbstractMatrix)
+# data .= ifelse.(data .!= 0, 1, 0)
+# else
+# throw(ArgumentError("Input must be a DataFrame or a Matrix"))
+# end
+# return data
+# end
+
+
+#=
+# Test workflow with small data
+
+Example DataFrames with row names (represented as the first column in the DataFrame)
+df1 = DataFrame(RowName = ["W", "X"], A = [1, 2], B = [3, 4], C = [5, 6])
+df2 = DataFrame(Gene = ["Y", "X"], B = [7, 8], C = [9, 10], D = [11, 12])
+df3 = DataFrame(Blue = ["P", "Q"], A = [13, 14], D = [15, 16], E = [17,18])
+
+# # # # List of all DataFrames you want to combine
+dfs = [df1, df2, df3]
+dfs = [convertToLong(df) for df in dfs]
+dfs = [convertToWide(df; indices = (1,2,3)) for df in dfs]
+
+tick()
+dfNorm = [frobeniusNormalize(df, :row) for df in dfs]
+tock()
+
+commonID = :Gene
+dfNorm = [rename!(df, names(df)[1] => commonID) for df in dfNorm]
+
+tick()
+mergeDFs(dfNorm, :Gene, "sum")
+tock()
+=#
+
+
+#Compute Frobenius norm for each row (excluding the first column)
+# norm_dfs = [DataFrame(A = df[!, 1], FrobeniusNorm = [norm(row[2:end], 2) for row in eachrow(df)]) for df in dfs] # row
+# norm_dfs = [DataFrame(Column = names(df)[2:end], FrobeniusNorm = [ norm(col, 2) for col in eachcol(df[!, 2:end]) ]) for df in dfs] # col
+
+
+# df1 = "/data/miraldiNB/Michael/projects/GRN/hCD4T_Katko/dataBank/Priors/SCENICp/majorCellType/tf2gene_only_TF2G_wide_binary.tsv"
+# df1 = CSV.read(df1, DataFrame; delim = "\t")
+# df1Norm = frobeniusNormalize(df1, :column)
+# check_column_norms(df1Norm; atol=1e-3)
+
+# df2 = "/data/miraldiNB/Michael/projects/GRN/hCD4T_Katko/dataBank/Priors/MotifScan5kbTSS_b.tsv"
+# df2 = CSV.read(df2, DataFrame; delim = "\t")
+# check_column_norms(df2; atol=1e-8)
+
+# dfs = [df1, df2]
+# dfs = [CSV.read(df1, DataFrame; delim = "\t") for df in dfpath]
+# merged = mergeDFs(dfs, :Column1, "sum")
+# mergedNorm = frobeniusNormalize(merged, :column)
+# check_column_norms(mergedNorm; atol=1e-8)
+# maximum(Matrix(merged_b[:, 2:end]))
+
+# merged_b = binarizeNonZero!(merged)
+# # order = sortperm(merged_b[:, "Column1"], rev = false)
+# # merged_b = merged_b[order, :]
+# writeTSVWithEmptyFirstHeader(merged_b, "/data/miraldiNB/Michael/hCD4T_Katko/dataBank/Priors/tf2gene_5kbTSS_b.tsv"; delim='\t')
+
+
+
diff --git a/src/Utils/networkIO.jl b/src/Utils/networkIO.jl
new file mode 100755
index 0000000..8c6d6cb
--- /dev/null
+++ b/src/Utils/networkIO.jl
@@ -0,0 +1,28 @@
+using JLD2
+using DelimitedFiles
+using Printf
+using Dates
+
+# Save all core structs
+function saveData(expressionData, tfaData, grnData, buildGrn, outputDir::String, fileName::String)
+ output = joinpath(outputDir, fileName)
+ @save output expressionData tfaData grnData buildGrn
+end
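+
+# Sketch: the saved file can be reloaded with JLD2's @load macro, which restores
+# the four structs by name, e.g.
+#   @load joinpath(outputDir, fileName) expressionData tfaData grnData buildGrn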
+
+# Write network tables
+function writeNetworkTable!(buildGrn; outputDir::String, networkName::Union{String, Nothing}=nothing)
+ baseName = (networkName === nothing || isempty(networkName)) ? "edges" : networkName * "edges"
+
+ outputFile = joinpath(outputDir, baseName * ".tsv")
+ colNames = "TF\tGene\tsignedQuantile\tStability\tCorrelation\tinPrior\n"
+ open(outputFile, "w") do io
+ write(io, colNames)
+ writedlm(io, buildGrn.networkMat)
+ end
+
+ outputFileSubset = joinpath(outputDir, baseName * "_subset.tsv")
+ open(outputFileSubset, "w") do io
+ write(io, colNames)
+ writedlm(io, buildGrn.networkMatSubset)
+ end
+end
diff --git a/src/Utils/partialCorrelation.jl b/src/Utils/partialCorrelation.jl
new file mode 100755
index 0000000..be8c2a6
--- /dev/null
+++ b/src/Utils/partialCorrelation.jl
@@ -0,0 +1,96 @@
+
+# Method 1: matrix inversion.
+# Computes partial correlations from the precision matrix (the inverse of the
+# variance-covariance matrix).
+using LinearAlgebra
+using Statistics # mean, cov
+
+function partialCorrelationMat(X::Matrix{Float64}; epsilon = 1e-7, first_vs_all = false)
+ # Mean centering of the columns (mean subtraction)
+ X_centered = X .- mean(X, dims=1)
+
+ # Compute the covariance matrix
+ sigma = cov(X_centered)
+    # Regularize the covariance matrix to avoid ill-conditioning and singularity
+ sigma = sigma + epsilon * I
+
+ # Precision matrix (inverse of the covariance matrix)
+ theta = inv(sigma)
+ p = size(X, 2) # number of features/predictors/explanatory variables
+
+ if first_vs_all
+ # Compute partial correlation for the first variable vs all others
+ P = ones(1, p) # Initialize the partial correlation matrix
+ for j in 2:p
+ P[j] = -theta[1, j] / sqrt(theta[1, 1] * theta[j, j])
+ end
+ else
+ # Compute the full partial correlation matrix
+ P = ones(p, p) # Initialize the partial correlation matrix
+ for i in 1:(p-1)
+ for j in (i+1):p
+ P[i , j] = -theta[i, j] / sqrt(theta[i, i] * theta[j, j])
+ P[j, i] = P[i, j] # Symmetry
+ end
+ end
+ end
+
+ return P
+end
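+
+#=
+# Sketch: for p = 2 the partial correlation reduces to the plain Pearson
+# correlation (exactly, up to the epsilon ridge added to the covariance matrix).
+X = randn(1000, 2); X[:, 2] .+= 0.5 .* X[:, 1]
+P = partialCorrelationMat(X)
+P[1, 2] # ≈ cor(X[:, 1], X[:, 2])
+=#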
+
+
+# Method 2: Using Regression
+using GLM, Statistics, DataFrames
+
+function partialCorrReg(X::Matrix{Float64}; first_vs_all = false)
+ p = size(X, 2)
+ P = ones(p, p)
+
+ # Convert the matrix to a DataFrame for easier variable handling
+ df = DataFrame(X, :auto);
+ col_names = names(df)
+
+ if first_vs_all
+ P = ones(1, p)
+ for j in 2:p
+ keep_indices = setdiff(2:p, j)
+ covariates = col_names[keep_indices] # All columns except i and j
+
+            # First, regress variable 1 on all covariates except j
+            model_1 = lm(term.(col_names[1]) ~ sum(term.(covariates)), df)
+            res_1 = residuals(model_1) # Residuals for variable 1
+
+            # Then, regress j on all covariates except 1
+            model_j = lm(term.(col_names[j]) ~ sum(term.(covariates)), df)
+            res_j = residuals(model_j) # Residuals for variable j
+
+            # Compute correlation between the residuals
+ r = cor(res_1, res_j)
+ P[j] = r
+ end
+ else
+ P = ones(p, p)
+ for i in 1:p
+ for j in (i+1):p
+
+ keep_indices = setdiff(1:p, [i ,j])
+ covariates = col_names[keep_indices] # All columns except i and j
+
+ # reference: https://discourse.julialang.org/t/using-all-independent-variables-with-formula-in-a-multiple-linear-model/43691/4
+
+ # First, regress i on all other features except j (i is the response)
+ model_i = lm(term.(col_names[i]) ~ sum(term.(covariates)), df)
+ res_i = residuals(model_i) # Get residuals for i
+
+ # Then, regress j on all others except i
+ model_j = lm(term.(col_names[j]) ~ sum(term.(covariates)), df)
+ res_j = residuals(model_j) # Get residuals for j
+
+ # Compute correlation between the residuals
+ r = cor(res_i, res_j)
+ P[i, j] = r
+ P[j, i] = r # Symmetric matrix
+ end
+ end
+ end
+ return P
+end
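+
+#=
+# Sketch: the precision-matrix and regression routes compute the same quantity,
+# so on well-conditioned data the two matrices should agree up to numerical
+# noise (plus the epsilon ridge in Method 1).
+X  = randn(200, 4)
+P1 = partialCorrelationMat(X)
+P2 = partialCorrReg(X)
+maximum(abs.(P1 .- P2)) # expected to be small
+=#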
+
+
diff --git a/src/data/GeneExpressionData.jl b/src/data/GeneExpressionData.jl
new file mode 100755
index 0000000..01ef5dd
--- /dev/null
+++ b/src/data/GeneExpressionData.jl
@@ -0,0 +1,205 @@
+# 00_Data/geneExpression.jl
+
+ using DelimitedFiles
+ using Statistics
+ using JLD2
+ using CSV
+ using Arrow
+ using DataFrames
+
+ # Struct defined in src/Types.jl
+
+
+ # Load Expression Data
+ """
+ loadExpressionData!(data::GeneExpressionData, geneExprFile)
+
+ Loads gene expression data from a specified file and updates the `GeneExpressionData` object.
+
+ # Arguments
+ - `data::GeneExpressionData`: The data object to be populated with gene expression details.
+ - `geneExprFile::String`: The path to the gene expression data file, which can be in `.arrow` format or tab-delimited text format.
+
+ # Updates
+ - `data.cellLabels`: A vector of sample condition names, derived from the column headers (excluding the first).
+ - `data.targGenes`: A vector of gene names, derived from the first column.
+ - `data.targGeneMat`: A matrix of expression values, converted to `Float64`.
+
+ # Raises
+ - An error if the gene expression file does not exist or if its path is invalid.
+
+ # Notes
+ - Assumes that the first column in the file contains gene names and subsequent columns contain cell labels and expression data.
+ - For text files, data is sorted by gene names.
+ """
+ function loadExpressionData!(data::GeneExpressionData, geneExprFile)
+ if isfile(geneExprFile)
+ if endswith(geneExprFile, ".arrow")
+ dfArrow = Arrow.Table(geneExprFile)
+ df = deepcopy(DataFrame(dfArrow))
+ dfArrow = nothing;
+ cellLabels = names(df)[2:end] # Assume first column is gene names
+ geneNames = df[:, 1] # Extract gene names
+ ncounts = Matrix(df[:, 2:end])
+ df = nothing
+ else
+ fid = open(geneExprFile);
+ C = readdlm(fid, '\t', '\n');
+ close(fid)
+ cellLabels = C[1, :];
+ cellLabels = filter(!isempty, cellLabels);
+ C = C[2:end, :];
+ inds = sortperm(C[:, 1]);
+ C = C[inds, :];
+ geneNames = C[:, 1];
+ ncounts = C[:, 2:end]
+ end
+ ncounts = convert(Matrix{Float64}, ncounts);
+
+ data.cellLabels = String.(cellLabels);
+ data.geneNames = String.(geneNames);
+ data.geneExpressionMat = ncounts;
+ ncounts = nothing;
+ else
+ error("Expression data file path is invalid.")
+ end
+ end
+
+
+ """
+    loadAndFilterTargetGenes!(data::GeneExpressionData, targetGeneFile; epsilon=1E-10)
+
+ Loads target genes from a file, filters them based on presence and variance in the existing gene expression data, and updates the `GeneExpressionData` object.
+
+ # Arguments
+ - `data::GeneExpressionData`: The data object to be updated with filtered target gene information.
+ - `targetGeneFile::String`: The path to the file containing target gene names.
+    - `epsilon::Float64`: A keyword argument specifying the variance threshold for filtering target genes. Default is `1E-10`.
+
+ # Updates
+ - `data.targGenes`: A filtered vector of target gene names that are present in the gene expression data and meet the variance threshold.
+ - `data.targGeneMat`: A matrix of expression values for these filtered target genes.
+
+ # Raises
+ - An error if the target gene file does not exist.
+ - An error if no target genes are found in the gene expression data.
+ - An error if all target genes are filtered out due to low variance.
+
+ # Notes
+ - This function finds matching genes and filters them based on a variance threshold to ensure sufficient expression variability across samples.
+ """
+ function loadAndFilterTargetGenes!(data::GeneExpressionData, targetGeneFile; epsilon=1E-10)
+ if isfile(targetGeneFile)
+ # Load in target gene file
+ fid = open(targetGeneFile)
+ targetGenes = readdlm(fid, String)
+ close(fid)
+
+        # Find all geneNames that are in the target gene file
+ inds = findall(in(targetGenes), data.geneNames)
+ if isempty(inds)
+ error("No target genes found in expression data!")
+ end
+
+ # Filter target genes and expression matrix
+ targGenes = data.geneNames[inds]
+ targGeneMat = data.geneExpressionMat[inds, :]
+
+ # Filter genes by minimum variance cutoff
+ stds = std(targGeneMat, dims=2)
+ keep = [index[1] for index in findall(stds .>= epsilon)]
+ targGenesFilter = targGenes[keep]
+ targGeneMatFilter = targGeneMat[keep, :]
+ if isempty(keep)
+ error("All target genes removed due to low variance!")
+ else
+ println(length(targGenesFilter), " target genes retained after filtering")
+ end
+ data.targGenes = targGenesFilter
+ data.targGeneMat = targGeneMatFilter
+ else
+ error("Target gene file not found.")
+ end
+ end
+
+ # Load Regulators
+ """
+ loadPotentialRegulators!(data::GeneExpressionData, potRegFile)
+
+ Loads potential regulators and updates the `GeneExpressionData` object with their expression data.
+
+ # Arguments
+ - `data::GeneExpressionData`: The data object to be populated with potential regulator information.
+ - `potRegFile::String`: The path to the file containing potential regulator names.
+
+ # Updates
+ - `data.potRegs`: List of all potential regulator names.
+ - `data.potRegsmRNA`: List of regulator names with corresponding mRNA expression data.
+ - `data.potRegMatmRNA`: Expression matrix for regulators with mRNA data.
+
+ # Raises
+ - An error if the potential regulators file does not exist.
+
+ # Notes
+ - Only the potential regulators present in the existing gene expression data are included.
+ """
+ function loadPotentialRegulators!(data::GeneExpressionData, potRegFile)
+ if isfile(potRegFile)
+ fid = open(potRegFile)
+ potentialRegs = readdlm(fid, String)
+ close(fid)
+ potentialRegs = vec(potentialRegs)
+
+ indsPotRegs = findall(in(data.geneNames), potentialRegs)
+ potentialRegs = potentialRegs[indsPotRegs]
+
+ inds = findall(in(potentialRegs), data.geneNames)
+ potRegsmRNA = data.geneNames[inds]
+ potRegMatmRNA = data.geneExpressionMat[inds, :]
+
+ data.potRegs = potentialRegs
+ data.potRegsmRNA = potRegsmRNA
+ data.potRegMatmRNA = potRegMatmRNA
+ else
+ error("Potential regulators file not found.")
+ end
+ end
+
+
+
+ # Process genes for TFA calculation
+    # Process genes for TFA calculation
+    """
+    processTFAGenes!(data::GeneExpressionData, tfaGeneFile; outputDir=nothing)
+
+    Selects the genes used for the TFA calculation and stores their expression data.
+
+    # Arguments
+    - `data::GeneExpressionData`: The data object to update.
+    - `tfaGeneFile::Union{String, Nothing}`: The path to the TFA genes file. If the path is `nothing`, empty, or invalid, all genes are used.
+    - `outputDir::Union{String, Nothing}`: If provided, the updated object is saved to `geneExprMat.jld` in this directory.
+
+    # Updates
+    - `data.tfaGenes`: A vector of TFA gene names with expression data.
+    - `data.tfaGeneMat`: A matrix of expression values for the TFA genes.
+
+    # Notes
+    - If the specified file does not exist, all genes are considered for TFA.
+    """
+ function processTFAGenes!(data::GeneExpressionData, tfaGeneFile::Union{String, Nothing}; outputDir::Union{String, Nothing}=nothing)
+ if (tfaGeneFile !== nothing) && (tfaGeneFile != "") && isfile(tfaGeneFile)
+ tfaGenes = readlines(tfaGeneFile)
+ else
+ tfaGenes = data.geneNames
+ end
+
+ inds = findall(in(tfaGenes), data.geneNames)
+ data.tfaGenes = data.geneNames[inds]
+ data.tfaGeneMat = data.geneExpressionMat[inds, :]
+
+ if outputDir !== nothing && outputDir !== ""
+
+ outputFile = joinpath(outputDir, "geneExprMat.jld")
+ save_object(outputFile, data)
+
+ end
+ end
diff --git a/src/data/PriorTFAData.jl b/src/data/PriorTFAData.jl
new file mode 100755
index 0000000..51f46c5
--- /dev/null
+++ b/src/data/PriorTFAData.jl
@@ -0,0 +1,183 @@
+ # Standard libraries
+ using LinearAlgebra
+ using Statistics
+ using DelimitedFiles
+ using JLD2
+
+ # Struct defined in src/Types.jl
+
+    """
+    processPriorFile!(priorData::PriorTFAData, expressionData::GeneExpressionData,
+                      priorFile; mergedTFsData=nothing, minTargets=3)
+
+    Reads and processes a prior file, storing the extracted information in a `PriorTFAData` object.
+
+    # Arguments
+    - `priorData::PriorTFAData`: The struct to be populated with data from the prior file.
+    - `expressionData::GeneExpressionData`: Expression data used to filter the prior to measured regulators and targets.
+    - `priorFile::String`: The path to the prior file, formatted with tab-separated values.
+    - `mergedTFsData`: Optional result of merging degenerate TFs; if provided, the merged prior replaces the raw one.
+    - `minTargets::Int`: Regulators with this many prior targets or fewer are dropped (default 3).
+
+ # Updates
+ - `priorData.pRegs`: Stores the list of transcription factors (TFs) from the prior file.
+ - `priorData.pTargs`: Stores the list of target genes from the prior file.
+ - `priorData.priorMatrix`: Stores the interactions matrix, indicating TF-gene relationships.
+
+ # Raises
+ - An error if the prior file does not exist.
+
+ # Notes
+ - Assumes the first row in the file contains TF names, and subsequent rows represent interactions with target genes.
+ - The gene list and matrix are sorted alphabetically by gene names.
+ """
+ function processPriorFile!(priorData::PriorTFAData,
+ expressionData::GeneExpressionData,
+ priorFile; mergedTFsData::Union{mergedTFsResult, Nothing}=nothing, minTargets = 3)
+
+ if isfile(priorFile)
+
+ println("--- Case 1")
+ fid = open(priorFile)
+ C = readdlm(fid, '\t', '\n', skipstart=0)
+ close(fid)
+
+ # Process and store genes and interactions matrix
+ pRegsTmp = convert(Vector{String}, filter(!isempty, C[1, :]))
+ C = C[2:end, :]
+ inds = sortperm(C[:, 1])
+ C = C[inds, :]
+ pTargsTmp = convert(Vector{String}, C[:, 1])
+ pMatrixTmp = convert(Matrix{Float64}, C[:, 2:end])
+
+ # Filter to only include those in potential regulator and target gene list
+ targInds = findall(in(expressionData.tfaGenes), pTargsTmp)
+ regInds = findall(in(expressionData.potRegs), pRegsTmp)
+ pRegsNoTfa = pRegsTmp[regInds]
+ pTargsNoTfa = pTargsTmp[targInds]
+ priorMatrixNoTfa = pMatrixTmp[targInds, regInds]
+
+        # Find TFs that have expression data but aren't in the prior
+ noPriorRegs = setdiff(expressionData.potRegsmRNA, pRegsNoTfa)
+ expInds = findall(in(noPriorRegs), expressionData.potRegsmRNA)
+ noPriorRegsMat = expressionData.potRegMatmRNA[expInds,:]
+
+ println("--- Case 2")
+ ## Case 2: prior-based TFA is used
+ # Check whether there were degenerate TFs and outputs
+ if (mergedTFsData !== nothing) &&
+ (mergedTFsData.mergedPrior !== nothing) &&
+ (mergedTFsData.mergedTFMap !== nothing)
+
+ println("-------- Using merge degenerate TFs prior file")
+ mergedTFs = mergedTFsData.mergedTFMap[:, 1]
+ individualTFs = mergedTFsData.mergedTFMap[:, 2]
+
+ totMergedSets = length(individualTFs)
+ keepMergedIndices = Int[]
+
+ for idx in 1:totMergedSets
+ currSet = split(individualTFs[idx], ", ")
+ usedTfs = intersect(currSet, expressionData.potRegs)
+ if !isempty(usedTfs)
+ push!(keepMergedIndices, idx)
+ end
+ end
+
+ if !isempty(keepMergedIndices) # add merged potential regulators to our list
+ expressionData.potRegs = union(expressionData.potRegs, mergedTFs[keepMergedIndices])
+ # Now load the merged prior matrix data (mergedPrior)
+ priorDF = mergedTFsData.mergedPrior
+ rowInd = sortperm(priorDF[:, 1])
+ pRegsTmp = names(priorDF)[2:end]
+ totPRegs = length(pRegsTmp)
+ pTargsTmp = priorDF[rowInd, 1]
+ pMatrixTmp = Matrix(priorDF[rowInd, 2:end])
+ end
+ end
+
+ ### Filter TFs and Genes for TFA calculation
+ pTargInds = findall(in(expressionData.tfaGenes), pTargsTmp)
+ pRegInds = findall(in(expressionData.potRegs), pRegsTmp)
+ pTargs = pTargsTmp[pTargInds]
+ pRegs = pRegsTmp[pRegInds]
+ pInts = pMatrixTmp[pTargInds,pRegInds]
+ # Filter for minimum target genes in prior
+ interactionsPerTF = sum((abs.(sign.(pInts))), dims=1)
+ keepRegs = Tuple.(findall(x -> x > minTargets, interactionsPerTF))
+ keepRegs = last.(keepRegs)
+ pInts = pInts[:,keepRegs]
+ pRegs = pRegs[keepRegs]
+ # Filter genes with no regulators
+ interactionsPerTarg = sum(abs.(sign.(pInts)), dims = 2)
+ keepTargs = Tuple.(findall(x -> x > 0, interactionsPerTarg))
+ keepTargs = first.(keepTargs)
+ pInts = pInts[keepTargs,:]
+ pTargs = pTargs[keepTargs]
+
+ ### Ensure expression and prior targets are in the same order
+ expTargInds = findall(in(pTargs), expressionData.tfaGenes)
+ targExp = expressionData.tfaGeneMat[expTargInds,:]
+ if !(pTargs == expressionData.tfaGenes[expTargInds])
+            println("Warning: gene order in expression matrix and prior matrix do not match!")
+ end
+
+ # Add data to object
+ priorData.pRegs = pRegs
+ priorData.pTargs = pTargs
+ priorData.priorMatrix = pInts
+ priorData.pRegsNoTfa = pRegsNoTfa
+ priorData.pTargsNoTfa = pTargsNoTfa
+ priorData.priorMatrixNoTfa = priorMatrixNoTfa
+ priorData.noPriorRegs = noPriorRegs
+ priorData.noPriorRegsMat = noPriorRegsMat
+ priorData.targExpression = targExp
+
+ else
+ error("Prior file not found.")
+ end
+ end
+
+ function calculateTFA!(priorData::PriorTFAData, expressionData::GeneExpressionData;
+ edgeSS = 0, zTarget::Bool = false, outputDir::Union{String, Nothing}=nothing)
+ priorMatrix = priorData.priorMatrix
+ targExp = priorData.targExpression
+ totTargs = size(priorMatrix, 1)
+ totPreds = size(priorMatrix, 2)
+ totConds = size(targExp, 2)
+
+ if zTarget
+ targExp = (targExp .- mean(targExp, dims=2)) ./ std(targExp, dims=2)
+ println("Target expression normalized (z-score per gene).")
+ end
+
+ if edgeSS > 0
+ tfas = zeros(Float64, edgeSS, totPreds, totConds)
+
+ for ss in 1:edgeSS
+ sPrior = zeros(Float64, totTargs, totPreds)
+
+ for col in 1:totPreds
+ currTargs = priorMatrix[:, col]
+ targInds = findall(!iszero, currTargs)
+ totCurrTargs = length(targInds)
+ ssampleSize = Int(ceil(0.63 * totCurrTargs))
+ ssample = sample(targInds, ssampleSize; replace = false) # subsample without replacement (StatsBase); rand would draw duplicates
+
+ sPrior[ssample, col] = priorMatrix[ssample, col]
+ end
+ tfas[ss, :, :] = sPrior \ targExp
+ end
+
+ priorData.medTfas = median(tfas, dims=1)
+ println("Median from ", string(edgeSS), " subsamples used for prior-based TFA.")
+ else
+ # Solve priorMatrix * TFA ≈ targExp in the least-squares sense:
+ # TFA = argmin ||priorMatrix * TFA - targExp||²
+ priorData.medTfas = priorMatrix \ targExp
+ println("No subsampling for prior-based TFA estimate.")
+ end
+
+ if outputDir !== nothing && !isempty(outputDir)
+
+ outputFile = joinpath(outputDir, "tfaMat.jld")
+ save_object(outputFile, priorData)
+
+ end
+ end
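The unsubsampled branch above is a single backslash solve. A minimal sketch with toy matrices (hypothetical values): `priorMatrix \ targExp` returns the TFA matrix minimizing ‖priorMatrix * TFA − targExp‖², and with a full-column-rank prior and noise-free expression it recovers TFA exactly.

```julia
using LinearAlgebra

# Toy prior: 3 target genes × 2 TFs (hypothetical values)
P = [1.0 0.0;
     0.5 1.0;
     0.0 1.0]
trueTFA = [2.0 1.0 0.0;     # 2 TFs × 3 conditions
           0.0 3.0 1.0]
X = P * trueTFA             # synthetic target expression, noise-free

TFA = P \ X                 # least-squares solve, as in calculateTFA!
```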
diff --git a/src/grn/AggregateNetworks.jl b/src/grn/AggregateNetworks.jl
new file mode 100755
index 0000000..38f4b5c
--- /dev/null
+++ b/src/grn/AggregateNetworks.jl
@@ -0,0 +1,301 @@
+"""
+ GRN
+
+This module provides functionality for building and combining Gene Regulatory Networks (GRNs).
+
+The GRN files are assumed to have the following columns:
+ TF, Gene, signedQuantile, Stability, Correlation, strokeVals, strokeWidth, inPrior
+Main function:
+- `combineGRNs` (internal): combine multiple GRN files with options for mean/max aggregation,
+ controlling edges per gene, and saving the combined network.
+
+Dependencies:
+- Uses `PriorUtils` for prior-related helper functions.
+- Includes internal utilities in `utilsGRN.jl`.
+"""
+
+# using ..DataUtils
+# using CSV, DataFrames, Statistics, StatsBase, Printf, Dates
+# using ArgParse
+
+# export combineGRNs
+# ────────────────────────────────────────────────────────────────
+# Helper: deterministic tie-breaker
+# ────────────────────────────────────────────────────────────────
+"""
+    pickRow(g, primary; order = :max)
+
+Return a single row index from SubDataFrame `g` according to a deterministic
+tie-breaking hierarchy.
+
+Arguments
+─────────
+g       - SubDataFrame produced by `groupby`
+primary - "Stability"/:Stability or "Quantile"/:Quantile
+order   - :max (default) or :min (optimise up or down)
+
+Tie-breaking
+────────────
+If primary = Stability : Stability → Quantile → |Correlation|
+If primary = Quantile  : Quantile → |Correlation|
+"""
+function pickRow(g, primary::Union{String,Symbol}; order::Symbol = :max)
+ prim = Symbol(primary) # normalise to Symbol
+
+ # 1) extreme value of the primary column
+ pVals = g[!, prim]
+ extreme = order === :max ? maximum(pVals) : minimum(pVals)
+ idxs = findall(==(extreme), pVals)
+ length(idxs) == 1 && return idxs[1]
+
+ # 2) additional tie-breakers
+ if prim == :Stability && "signedQuantile" in names(g)
+ qVals = g.signedQuantile[idxs]
+ extQ = order === :max ? maximum(qVals) : minimum(qVals)
+ idxs = idxs[findall(==(extQ), qVals)]
+ length(idxs) == 1 && return idxs[1]
+ # if still tied we fall through to correlation
+ end
+
+ if "Correlation" in names(g)
+ absCorr = abs.(g.Correlation[idxs])
+ bestC = maximum(absCorr) # always take largest |ρ|
+ idxs = idxs[findall(==(bestC), absCorr)]
+ end
+
+ return idxs[1] # deterministic fallback
+end
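A usage sketch for `pickRow` (assuming the function above is in scope; the frame values are hypothetical): two rows tie on `Stability`, so the `signedQuantile` tie-breaker decides.

```julia
using DataFrames

df = DataFrame(TF = ["A", "A"], Gene = ["g1", "g1"],
               Stability = [0.9, 0.9],         # tied primary column
               signedQuantile = [0.4, 0.8],    # breaks the tie
               Correlation = [0.2, -0.1])
g = groupby(df, ["TF", "Gene"])[1]

pickRow(g, "Stability")   # returns 2: the row with the larger signedQuantile
```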
+
+
+# ────────────────────────────────────────────────────────────────
+# Main function
+# ────────────────────────────────────────────────────────────────
+"""
+    aggregateNetworks(nets2combine; method = :max, meanEdgesPerGene = 20,
+                      useMeanEdgesPerGene = true, outputDir = nothing, saveName = "")
+
+# Arguments
+- `nets2combine::Vector{String}`: A vector of GRN file names to combine.
+- `method::Union{String,Symbol}`: Combination strategy — :max, :min, :mean, :meanQuantile, :maxQuantile.
+- `meanEdgesPerGene::Int`: Mean number of edges per target gene.
+- `useMeanEdgesPerGene::Bool`: If `true`, selects `meanEdgesPerGene * uniqueGenes` edges globally.
+  If `false`, selects the top `meanEdgesPerGene` edges per target gene.
+- `outputDir::Union{String,Nothing}`: Directory to save results. Defaults to `nothing`.
+- `saveName::String`: Optional filename prefix for saved files. Defaults to `""`.
+"""
+function aggregateNetworks(nets2combine::Vector{String};
+ method::Union{String,Symbol} = :max,
+ meanEdgesPerGene::Int = 20,
+ useMeanEdgesPerGene::Bool = true,
+ outputDir::Union{String,Nothing} = nothing,
+ saveName::String = "")
+
+ combineOpt = string(method) # normalise Symbol/String → String for comparisons
+ useMeanEdgesPerGeneMode = useMeanEdgesPerGene
+ saveDir = outputDir
+ allDfs = DataFrame[]
+
+ # ──── 1. Read each file and then combine; create a rank column if not present ───────────────────────────
+ for file in nets2combine
+ # Files are read as tab-separated (TSV); adjust the delimiter if a different format is used
+ df = CSV.read(file, DataFrame; delim='\t')
+ # println(size(df))
+
+ # # Ensure the required columns exist
+ # requiredCols = ["TF", "Gene", "signedQuantile", "Stability", "Correlation", "inPrior"]
+ # dfCols = names(df)
+ # for col in requiredCols
+ # if !(col in dfCols)
+ # error("Column $(col) not found in file $(file)")
+ # end
+ # end
+ # df = df[!, requiredCols]
+
+ # ensure numeric columns are Float64
+ for col in [:Stability, :Correlation]
+ if !(eltype(df[!, col]) <: AbstractFloat)
+ df[!, col] = parse.(Float64, string.(df[!, col]))
+ end
+ end
+
+ # fill in the signed-quantile column ONLY if it does not already exist
+ if !("signedQuantile" in names(df))
+ st_ecdf = ecdf(df.Stability)
+ df.signedQuantile = st_ecdf.(df.Stability) .* sign.(df.Correlation) # singular name, matching the check and the aggregation branches below
+ end
+ push!(allDfs, df)
+ end
+
+ # ──── 2. CONCATENATE AND AGGREGATE ACCORDING TO combineOpt ───────────────────────────
+ # Combine all dataframes vertically
+ combinedDf = vcat(allDfs...)
+ # Group by TF and Gene and aggregate according to the chosen method.
+ groupedDf = groupby(combinedDf, ["TF", "Gene"])
+ aggRows = NamedTuple[]
+
+ for g in groupedDf
+ # All rows in this group share the same TF, Gene and inPrior
+ tf, gene, dash = g.TF[1], g.Gene[1], g.inPrior[1]
+ # tf, gene, dash = g.TF[1], g.Gene[1], g.strokeDashArray[1]
+
+ if combineOpt == "mean"
+ c = mean(g.Correlation)
+ push!(aggRows, (TF = tf, Gene = gene, Stability = mean(g.Stability),
+ Correlation = c,
+ inPrior = dash))
+
+ elseif combineOpt == "max"
+ # idx = combineOpt == "max" ? argmax(g.Stability) : argmin(g.Stability)
+ idx = pickRow(g, "Stability", order = :max)
+ row = g[idx, :]
+ push!(aggRows, (TF = tf, Gene = gene, signedQuantile = row.signedQuantile, Stability = row.Stability,
+ Correlation = row.Correlation, inPrior = row.inPrior))
+
+ elseif combineOpt == "min"
+ # idx = combineOpt == "max" ? argmax(g.Stability) : argmin(g.Stability)
+ idx = pickRow(g, "Stability", order = :min)
+ row = g[idx, :]
+ push!(aggRows, (TF = tf, Gene = gene, Stability = row.Stability,
+ Correlation = row.Correlation, inPrior = row.inPrior))
+
+ elseif combineOpt == "meanQuantile"
+ c = mean(g.Correlation)
+ push!(aggRows, (TF = tf, Gene = gene, signedQuantile = mean(g.signedQuantile), Stability = mean(g.Stability),
+ Correlation = c,
+ inPrior = dash))
+
+ elseif combineOpt == "maxQuantile"
+ # idx = argmax(g.signedQuantile)
+ idx = pickRow(g, "signedQuantile", order = :max)
+ row = g[idx, :]
+ push!(aggRows, (TF = tf, Gene = gene, signedQuantile = row.signedQuantile, Stability = row.Stability, Correlation = row.Correlation,
+ inPrior = row.inPrior))
+ else
+ error("Invalid combineOpt: $(combineOpt).")
+ end
+ end
+
+ aggregatedDf = DataFrame(aggRows)
+
+
+ # ────── 3. Filter out extremely weak correlations and compute new signed quantiles using broadcasting ──────────────────────
+ pcut = 0.01 # Correlation cutoff
+ aggregatedDf = filter(:Correlation => x -> abs(x) > pcut, aggregatedDf)
+ # ── 3b. Compute signedQuantile ─────────────────
+ if combineOpt in ("max","min","mean")
+ stability = aggregatedDf.Stability # Vector
+ F = ecdf(stability) # F(x) = proportion ≤ x
+ signed_q = F.(stability) .* sign.(aggregatedDf.Correlation)
+ aggregatedDf[!,:signedQuantile] = signed_q
+ end
+
+ # Checks:
+ # abs_q = abs.(aggregatedDf.signedQuantile)
+ # i_min = argmin(abs_q)
+ # row_min = aggregatedDf[i_min, :]
+
+ # sort for reproducibility
+ sort!(aggregatedDf, :signedQuantile, by= abs, rev = true)
+ # sort!(aggregatedDf, combineOpt in ("meanQuantile", "maxQuantile") ? :signedQuantile : :Stability,
+ # by = combineOpt in ("meanQuantile", "maxQuantile") ? abs : identity, rev = true)
+
+ # ────── 5. choose top meanEdgesPerGene * nGenes edges ──────────────────────
+
+ if useMeanEdgesPerGeneMode
+ println("Selecting the top (meanEdgesPerGene * unique targets) edges")
+ # Select the top `topEdgeCount = meanEdgesPerGene * uniqueGenes` & compute quantiles for all rankings
+ uniqueGenes = length(unique(aggregatedDf.Gene)) # current number of unique targets
+ topEdgeCount = round(Int, meanEdgesPerGene * uniqueGenes) # Compute the number of edges to select
+ selectionIndices = 1:min(topEdgeCount, nrow(aggregatedDf))
+ else
+ println("Selecting the top meanEdgesPerGene edges per target gene")
+ selectionIndices = firstNByGroup(aggregatedDf.Gene, meanEdgesPerGene)
+ end
+ selectedDf = aggregatedDf[selectionIndices, :]
+ println("Selected $(length(selectionIndices)) edges using meanEdgesPerGene = $meanEdgesPerGene.")
+
+ # # ────── 6. Generate strokeWidth and colors
+ # # ────── Color calculations for JP-Gene-Viz
+ # minRank, maxRank = extrema(selectedDf.Stability)
+ # rankRange = max(maxRank - minRank, eps()) # prevent division by zero
+ # strokeWidth = 1 .+ (selectedDf.Stability .- minRank) ./ rankRange
+ # # Color mapping
+ # medRed = [228, 26, 28]
+ # lightGrey = [217, 217, 217]
+ # strokeVals = map(abs.(selectedDf.Correlation)) do corr
+ # color = corr * medRed .+ (1 - corr) * lightGrey
+ # "rgb($(floor(Int, round(color[1]))),$(floor(Int, round(color[2]))),$(floor(Int, round(color[3]))))"
+ # # "rgb(" * string(floor(Int, round(color[1]))) * "," * string(floor(Int, round(color[2]))) * "," * string(floor(Int, round(color[3]))) * ")"
+ # end
+ # selectedDf[!, :strokeVals] = strokeVals
+ # selectedDf[!, :strokeWidth] = strokeWidth
+
+
+ # ── 7. save to disk if saveDir is provided ──────────────────────────────────────────
+ if saveDir !== nothing
+ mkpath(saveDir)
+ tag = isempty(saveName) ? combineOpt : "$(saveName)_$(combineOpt)" # saveName is a String, so only the empty check is needed
+ # CSV.write(joinpath(saveDir, "$(tag)_aggregated.tsv"), aggregatedDf; delim = '\t')
+ CSV.write(joinpath(saveDir, "combined_$(tag).tsv"), selectedDf; delim = '\t')
+
+ wideDf = convertToWide(selectedDf[ :,["TF","Gene","signedQuantile"]]; indices=(2,1,3))
+ wideDf = coalesce.(wideDf, 0)
+ wideFile = joinpath(saveDir, "combined_$(tag)_sp.tsv")
+ writeTSVWithEmptyFirstHeader(wideDf, wideFile; delim = '\t')
+ println("Files written under ", saveDir)
+ end
+
+ return selectedDf
+end
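The signed-quantile recomputation in step 3b, in isolation (toy stability/correlation vectors): `ecdf` from StatsBase maps each stability to its rank fraction, and the correlation sign is then reattached by broadcasting.

```julia
using StatsBase

stability   = [0.2, 0.5, 0.9, 0.9]
correlation = [0.3, -0.4, 0.1, -0.2]

F = ecdf(stability)                 # F(x) = fraction of values ≤ x
quantiles = F.(stability)           # [0.25, 0.5, 1.0, 1.0]
signed_q = quantiles .* sign.(correlation)
```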
+
+# -----------------------
+# main procedure
+# -----------------------
+# This section executes if the script is run directly. It will not execute if called/imported into another script.
+if abspath(PROGRAM_FILE) == @__FILE__
+ using ArgParse
+
+ # Define argument parser
+ s = ArgParseSettings()
+ @add_arg_table s begin
+ "--combineOpt"
+ help = "Combination option: max, min, or mean (default: mean)"
+ arg_type = String
+ default = "mean"
+ "--meanEdgesPerGene"
+ help = "Mean number of edges per unique gene."
+ arg_type = Int
+ required = true
+ "--useMeanEdgesPerGeneMode"
+ help = "Controls whether edge selection is global or per gene. If true, selects length(unique(targs)) * meanEdgesPerGene edges globally;
+ if false, selects the top meanEdgesPerGene edges per target gene."
+ arg_type = Bool
+ default = true
+ "--saveDir"
+ help = "Directory in which to save output files."
+ arg_type = String
+ default = ""
+ "--saveName"
+ help = "Base name for the saved file."
+ arg_type = String
+ default = ""
+ "files"
+ nargs = '+'
+ arg_type = String
+ help = "List of GRN files (TSV format) to combine."
+
+ parsedArgs = parse_args(s)
+ fileList = parsedArgs["files"]
+ combineOpt = parsedArgs["combineOpt"]
+ meanEdgesPerGene = parsedArgs["meanEdgesPerGene"]
+ useMeanEdgesPerGeneMode = parsedArgs["useMeanEdgesPerGeneMode"]
+ saveDir = isempty(parsedArgs["saveDir"]) ? nothing : parsedArgs["saveDir"]
+ saveName = parsedArgs["saveName"]
+
+ @printf("Combining %d files with combination option: %s\n", length(fileList), combineOpt)
+
+ selectedDf = aggregateNetworks(fileList; method=combineOpt, meanEdgesPerGene=meanEdgesPerGene,
+ useMeanEdgesPerGene=useMeanEdgesPerGeneMode, outputDir=saveDir, saveName=saveName)
+
+ # Optionally, print summary
+ @printf("Final network has %d edges spanning %d unique genes.\n", nrow(selectedDf), length(unique(selectedDf.Gene)))
+end
+
diff --git a/src/grn/BuildGRN.jl b/src/grn/BuildGRN.jl
new file mode 100755
index 0000000..3ed4753
--- /dev/null
+++ b/src/grn/BuildGRN.jl
@@ -0,0 +1,377 @@
+# mutable struct BuildGrn
+# networkStability::Matrix{Float64}
+# lambda::Union{Float64, Vector{Float64}}
+# targs::Vector{String}
+# regs::Vector{String}
+# rankings::Vector{Float64}
+# signedQuantile::Vector{Float64}
+# partialCorrelation::Vector{Float64}
+# inPrior::Vector{String}
+# networkMat::Matrix{Any}
+# networkMatSubset::Matrix{Any}
+# inPriorVec::Vector{Float64}
+# betas::Matrix{Float64}
+# function BuildGrn()
+# return new(
+# Matrix{Float64}(undef, 0, 0), # networkStability
+# 0.0, # lambda
+# [], # targs
+# [], # regs
+# [], # rankings
+# [], # signedQuantile
+# [], # partialCorrelation
+# [], # inPrior
+# Matrix{Float64}(undef, 0, 0), # networkMat
+# Matrix{Float64}(undef, 0, 0), # networkMatSubset
+# [], # inPriorVec
+# Matrix{Float64}(undef, 0, 0) # betas
+# )
+# end
+# end
+
+function chooseLambda!(grnData::GrnData, buildGrn::BuildGrn; instabilityLevel = "Gene", targetInstability = 0.05)
+ totLambdas, totNetGenes, totNetTfs = size(grnData.stabilityMat)
+ networkStability = zeros(Float64, totNetGenes, totNetTfs)
+ # lambdaRange = reverse(grnData.lambdaRange)
+ lambdaRange = grnData.lambdaRange
+ betas = zeros(totNetGenes, totNetTfs)
+ lambdaTrack = []
+ ## transform StARS instabilities into stabilities
+ if instabilityLevel == "Gene"
+ totMins = 0
+ totMaxs = 0
+ for targ = 1:totNetGenes
+ currInstabs = grnData.geneInstabilities[targ, :]
+ devs = abs.(currInstabs .- targetInstability)
+ globalMin = minimum(devs)
+ minInds = findall(x -> x == globalMin[1], devs) # globalMin
+ minInd = minInds[end]
+
+ lambdaRangeGene = grnData.lambdaRangeGene[targ]
+ push!(lambdaTrack, lambdaRangeGene[minInd])
+
+ networkStability[targ, :] = grnData.stabilityMat[minInd, targ, :]
+ betas[targ, :] = grnData.betas[targ, :, minInd]
+
+ if minInd == 1
+ totMins += 1
+ elseif minInd == totLambdas
+ totMaxs += 1
+ end
+ end
+ if totMins > 0
+ println("Target instability reached at minimum lambda for ", string(totMins), " gene(s)")
+ end
+ if totMaxs > 0
+ println("Target instability reached at maximum lambda for ", string(totMaxs), " gene(s)")
+ end
+
+ elseif instabilityLevel == "Network"
+ println("Network instabilities detected.")
+ # find the single lambda whose instability is closest to the cutoff;
+ # on ties, keep the last index (larger lambda)
+ devs = abs.(grnData.netInstabilities .- targetInstability)
+ minInds = findall(==(minimum(devs)), devs)
+ minInd = minInds[end]
+ if minInd == 1
+ println("Minimum lambda was used for maximum instability ", string(grnData.netInstabilities[minInd]), " to reach target cut = ", string(targetInstability), ".")
+ elseif minInd == totLambdas
+ println("Maximum lambda was used for minimum instability ", string(grnData.netInstabilities[minInd]), " to reach target cut = ", string(targetInstability), ".")
+ end
+ networkStability[:, :] = grnData.stabilityMat[minInd, :, :]
+ betas = grnData.betas[:, :, minInd]
+ lambdaTrack = lambdaRange[minInd]
+
+ else
+ error("instabSource not recognized, should be either Gene or Network.")
+ end
+
+ # ssMatrix has infinity entries to mark illegal TF-gene interactions
+ # (e.g., TF mRNA TFA cannot be used to predict TF gene expression)
+ networkStabilityVec = networkStability[:]
+ networkStabilityVec[isinf.(networkStabilityVec)] .= 0
+ networkStability = reshape(networkStabilityVec, totNetGenes, totNetTfs)
+
+ if length(lambdaTrack) == 1
+ lambdaTrack = Float64(lambdaTrack[1]) # unwrap to a scalar (convert(Float64, ::Vector) would error)
+ else
+ lambdaTrack = convert(Vector{Float64}, lambdaTrack)
+ end
+
+ buildGrn.lambda = lambdaTrack
+ buildGrn.networkStability = networkStability
+ buildGrn.betas = betas
+end
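The per-gene selection rule above reduces to: pick the lambda whose estimated instability is closest to `targetInstability`, keeping the last index on ties (as `minInds[end]` does). A toy sketch with a hypothetical instability curve:

```julia
instabs = [0.20, 0.08, 0.05, 0.05, 0.01]    # instability per lambda (toy values)
target  = 0.05

devs = abs.(instabs .- target)
minInds = findall(==(minimum(devs)), devs)  # indices tied at the minimum deviation
minInd  = minInds[end]                      # last tied index wins, as in chooseLambda!
```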
+
+# function firstNByGroup(vect::AbstractVector, N::Integer)
+# """
+# firstNByGroupIndices(vec::AbstractVector, N::Integer)
+
+# Return the indices of the first `N` occurrences of each unique element in `vec`.
+
+# This is useful when you want to subset another array based on limited occurrences per group.
+
+# # Arguments
+# - `vec`: A vector of values to group by (e.g., transcription factors).
+# - `N`: The maximum number of elements to select per unique group.
+
+# # Returns
+# - A vector of indices corresponding to the first `N` entries per group.
+# """
+# seen = Dict{eltype(vect), Int}()
+# idxs = Int[]
+# for (i, v) in enumerate(vect)
+# seen[v] = get(seen, v, 0) + 1
+# if seen[v] <= N
+# push!(idxs, i)
+# end
+# end
+# return idxs
+# end
+
+function rankEdges!(expressionData::GeneExpressionData, tfaData::PriorTFAData, grnData::GrnData, buildGrn::BuildGrn;
+ mergedTFsData::Union{mergedTFsResult, Nothing}=nothing,
+ useMeanEdgesPerGeneMode = true,
+ meanEdgesPerGene = 20,
+ correlationWeight = 1,
+ outputDir::Union{String, Nothing}=nothing)
+
+ if outputDir === nothing || isempty(outputDir) # isempty(nothing) would throw
+ error("You must provide a non-empty outputDir to save the results.")
+ end
+
+ predictorMat = grnData.predictorMat
+ allPredictors = grnData.allPredictors
+
+ totLambdas, totNetGenes, totNetTfs = size(grnData.stabilityMat)
+ totInts = totNetGenes * totNetTfs
+ targs = repeat(expressionData.targGenes, totNetTfs, 1)
+ regs = vec(repeat(permutedims(allPredictors), totNetGenes,1))
+ rankTmp = buildGrn.networkStability[:]
+ keepInds = findall(x -> x != 0 && x != Inf, rankTmp) # keep nonzero, remove infinite values (e.g., corresponding to TF-TF edges when TF mRNA used for TFA)
+ rankTmp = rankTmp[keepInds]
+ inds = reverse(sortperm(rankTmp))
+ rankings = sort(rankTmp,rev=true)
+ regs = regs[keepInds[inds]]
+ targs = targs[keepInds[inds]]
+ totInfInts = length(rankings)
+
+ if useMeanEdgesPerGeneMode
+ totQuantEdges = length(unique(targs)) * meanEdgesPerGene
+ selectionIndices = 1:min(totQuantEdges, totInfInts)
+ else
+ selectionIndices = firstNByGroup(targs, meanEdgesPerGene)
+ end
+
+ # Compute quantiles
+ F = ecdf(rankings[selectionIndices]) # uses the StatsBase package
+ quantiles = F.(rankings[selectionIndices])
+
+ ## take what's in the meanEdgesPerGene network and get partial correlations
+ allCoefs = zeros(totNetGenes,totNetTfs)
+ allQuants = zeros(totNetGenes,totNetTfs)
+ allStabsTest = buildGrn.networkStability
+ keptTargs = permutedims(targs[selectionIndices])
+ uniTargs = unique(keptTargs)
+ totUniTargs = length(uniTargs)
+ tfsPerGene = zeros(totUniTargs,1)
+
+ for targ = ProgressBar(1:totUniTargs)
+ currTarg = uniTargs[targ]
+ targRankInds = last.(Tuple.(findall(x -> x == currTarg, keptTargs)))
+ currRegs = regs[targRankInds]
+
+ # --- Deduplicate currRegs while keeping aligned targRankInds ---
+ uniCurrRegsInds = firstNByGroup(currRegs, 1) # returns the index where each unique currRegs value first appears
+ uniCurrRegs = currRegs[uniCurrRegsInds]
+ targRankInds = targRankInds[uniCurrRegsInds] # Ensure the target corresponding to the duplicated Regulator is removed
+ # --- Get target index in full gene list
+ targInd = last.(Tuple.(findall(x -> x == currTarg, expressionData.targGenes)))
+ tfsPerGene[targ] = Int.(length(targRankInds))
+ # tfsPerGene = Int.(tfsPerGene)
+
+ # --- Match predictors ---
+ matchedRegs = intersect(allPredictors,currRegs)
+ # inds = findall(in(matchedRegs), allPredictors)
+ # regressIndsMat = first.(inds)
+ regressIndsMat = findall(in(matchedRegs), allPredictors)
+ rankVecInds = findall(in(matchedRegs),uniCurrRegs)
+ # --- Prepare input matrix: target + predictors ---
+ currTargVals = vec(transpose(grnData.responseMat[targInd,:]))
+ currPredVals = transpose(predictorMat[regressIndsMat,:])
+ combTargPreds = vcat(currTargVals', currPredVals')'
+ combTargPreds = permutedims(combTargPreds')
+ prho = partialCorrelationMat(combTargPreds; first_vs_all = true)
+ prho = prho[2:end]
+ prho = vec(prho)
+
+ if !any(isnan, prho) # NaN signals a singular fit; note NaN == NaN is false, so comparing with == never matches
+ allCoefs[targInd,regressIndsMat] = prho
+ else
+ println(currTarg, " pcorr was singular, # TFs = ", string(length(regressIndsMat)))
+ end
+
+ allQuants[targInd,regressIndsMat] = quantiles[targRankInds[rankVecInds]]
+ allStabsTest[targInd,regressIndsMat] = allStabsTest[targInd,regressIndsMat] + (correlationWeight .* transpose(round.(abs.(prho),digits = 4)))
+ end
+
+ inPriorMat = sign.(abs.(grnData.priorMatProcessed))
+
+ mergeTfLocVec = zeros(totNetTfs) # for keeping track of merged TFs (needed for partial correlation calculation)
+ if (mergedTFsData !== nothing) &&
+ (mergedTFsData.mergedPrior !== nothing) &&
+ (mergedTFsData.mergedTFMap !== nothing)
+
+ mergedTFs = mergedTFsData.mergedTFMap[:, 1]
+ individualTFs = mergedTFsData.mergedTFMap[:, 2]
+ totMerged = length(individualTFs)
+
+ rmInds = [] # remove merged TFs from regulators
+ addRegs = []
+ addInts = []
+ addCoefs = []
+ addPMat = []
+ addPredMat = []
+ addLoc = []
+ addQuants = []
+ for mind = 1:totMerged
+ mTf = mergedTFs[mind]
+ inputLocs = findall(x -> x == mTf,allPredictors)
+ totMInts = length(inputLocs) # number of interactions for merged TF
+ rmInds = [rmInds; inputLocs];
+ if length(inputLocs) > 0
+ indTfs = intersect(permutedims(split(individualTFs[mind],", ")),tfaData.pRegsNoTfa); # intersect ensures that TF was a potential regulator (e.g., based on gene expression)
+ totIndTfs = length(indTfs)
+ for indt = 1:totIndTfs
+ indTf = indTfs[indt]
+ append!(addRegs, fill(indTf, 1))
+ append!(addInts, allStabsTest[:,inputLocs])
+ append!(addPMat, inPriorMat[:,inputLocs])
+ append!(addQuants, allQuants[:,inputLocs])
+ append!(addCoefs, allCoefs[:,inputLocs]) # partial correlations for the merged TF's columns
+ append!(addPredMat, predictorMat[inputLocs,:])
+ append!(addLoc, mind)
+ end
+ end
+ end
+ println("Total of ", string(length(rmInds)), " TFs expanded.")
+ keepInds = setdiff(1:totNetTfs,rmInds)
+ # remove merged TFs and add individual TFs
+ if length(addRegs) > 0
+ allPredictors = collect(allPredictors[keepInds])
+ append!(allPredictors, addRegs)
+ allStabsTest = hcat(allStabsTest[:,keepInds], reshape(addInts, size(allStabsTest)[1], Int(size(addInts)[1] / size(allStabsTest)[1])))
+ allCoefs = hcat(allCoefs[:,keepInds], reshape(addCoefs, size(allStabsTest)[1], Int(size(addInts)[1] / size(allStabsTest)[1])))
+ allQuants = hcat(allQuants[:,keepInds], reshape(addQuants, size(allStabsTest)[1], Int(size(addInts)[1] / size(allStabsTest)[1])))
+ inPriorMat = hcat(inPriorMat[:,keepInds], reshape(addPMat, size(inPriorMat)[1], Int(size(addPMat)[1] / size(inPriorMat)[1])))
+ predictorMat = vcat(predictorMat[keepInds, :], reshape(addPredMat, Int((size(addPredMat)[1]) / size(predictorMat)[2]), size(predictorMat)[2]))
+ append!(mergeTfLocVec, addLoc)
+ end
+ else
+ println("No merged TFs file found.")
+ end
+
+ ## re-rank based on possibly de-merged TFs
+ rankings = allStabsTest[:]
+ coefVec = allCoefs[:]
+
+
+ inPriorVec = inPriorMat[:]
+ totNetTfs = length(allPredictors)
+ totInts = totNetGenes * totNetTfs
+ targs = repeat((expressionData.targGenes),totNetTfs,1)
+ regs1 = repeat(permutedims(allPredictors),totNetGenes,1)
+ regs = reshape(regs1,totInts,1)
+
+ rankings = convert(Vector{Float64}, rankings)
+ coefVec = convert(Vector{Float64}, coefVec)
+ inPriorVec = convert(Vector{Float64}, inPriorVec)
+ predictorMat = convert(Matrix{Float64}, predictorMat)
+ allStabsTest = convert(Matrix{Float64}, allStabsTest)
+ allCoefs = convert(Matrix{Float64}, allCoefs)
+ allQuants = convert(Matrix{Float64}, allQuants)
+ inPriorMat = convert(Matrix{Float64}, inPriorMat)
+
+ ## only keep nonzero rankings
+ keepRankings = findall(x -> x != 0 && x != Inf, rankings)
+ indsMerged = sortperm(rankings[keepRankings])
+ indsMerged = reverse(indsMerged)
+
+ # update info sources
+ rankings = rankings[keepRankings[indsMerged]]
+ coefVec = coefVec[keepRankings[indsMerged]]
+ inPriorVec = inPriorVec[keepRankings[indsMerged]]
+ regs = regs[keepRankings[indsMerged]]
+ targs = targs[keepRankings[indsMerged]]
+ totInfInts = length(rankings)
+
+ # totQuantEdges = length(unique(targs))*meanEdgesPerGene
+ # Compute quantiles
+ F = ecdf(rankings)
+ quantiles = F.(rankings)
+
+ # Compute signedQuantiles
+ signedQuantile = sign.(coefVec) .* quantiles
+
+ ## ------- Color calculations for JP-Gene-Viz
+ # Compute strokeWidth and colors
+ minRank, maxRank = extrema(rankings)
+ rankRange = max(maxRank - minRank, eps()) # prevent division by zero
+ # Dash Styling
+ strokeDashArray = ifelse.(inPriorVec .!= 0, "Yes", "No")
+ networkMatrix = hcat(regs, targs, signedQuantile, rankings, coefVec, strokeDashArray)
+ if useMeanEdgesPerGeneMode
+ totQuantEdges = length(unique(targs)) * meanEdgesPerGene
+ # selectionIndices = 1:totQuantEdges
+ selectionIndices = 1:min(totQuantEdges, totInfInts)
+ else
+ selectionIndices = firstNByGroup(targs, meanEdgesPerGene)
+
+ end
+
+ networkMatrixSubset = hcat(regs[selectionIndices], targs[selectionIndices], signedQuantile[selectionIndices],
+ rankings[selectionIndices], coefVec[selectionIndices], strokeDashArray[selectionIndices]
+ )
+
+ buildGrn.regs = regs
+ buildGrn.targs = targs
+ buildGrn.signedQuantile = signedQuantile
+ buildGrn.rankings = rankings
+ buildGrn.partialCorrelation = coefVec
+ buildGrn.inPrior = strokeDashArray
+ buildGrn.networkMat = networkMatrix
+ buildGrn.networkMatSubset = networkMatrixSubset
+ buildGrn.inPriorVec = inPriorVec
+ buildGrn.mergeTfLocVec = mergeTfLocVec # just added
+
+
+ if outputDir !== nothing && !isempty(outputDir)
+ outputFile = joinpath(outputDir, "grnOutMat.jld")
+ save_object(outputFile, buildGrn)
+ end
+end
+
+# function saveData(expressionData::GeneExpressionData, tfaData::PriorTFAData, grnData::GrnData, buildGrn::BuildGrn, outputDir::String, fileName::String)
+# output = outputDir * "/" * fileName
+# @save output expressionData tfaData grnData buildGrn
+# end
+
+# function writeNetworkTable!(buildGrn::BuildGrn; outputDir::String, networkName::Union{String, Nothing}=nothing)
+# local baseName = (networkName === nothing || isempty(networkName)) ? "edges" : networkName * "edges"
+
+# outputFile = joinpath(outputDir, baseName * ".tsv")
+# colNames = "TF\tGene\tsignedQuantile\tStability\tCorrelation\tinPrior\n"
+# open(outputFile, "w") do io
+# write(io, colNames)
+# writedlm(io, buildGrn.networkMat)
+# end
+
+# outputFileSubset = joinpath(outputDir, baseName * "_subset.tsv")
+# open(outputFileSubset, "w") do io
+# write(io, colNames)
+# writedlm(io, buildGrn.networkMatSubset)
+# end
+# end
diff --git a/src/grn/PrepareGRN.jl b/src/grn/PrepareGRN.jl
new file mode 100755
index 0000000..d3cb15f
--- /dev/null
+++ b/src/grn/PrepareGRN.jl
@@ -0,0 +1,454 @@
+"""
+GRN
+
+Functions:
+- preparePredictorMat!
+- preparePenaltyMatrix!
+- constructSubsamples
+- bstarsWarmStart
+- bstarsEstimateInstability
+
+Dependencies:
+"""
+
+# mutable struct GrnData
+# predictorMat::Matrix{Float64}
+# penaltyMat::Matrix{Float64}
+# allPredictors::Vector{String}
+# subsamps::Matrix{Int64}
+# responseMat::Matrix{Float64}
+# maxLambdaNet::Float64
+# minLambdaNet::Float64
+# minLambdas::Matrix{Float64}
+# maxLambdas::Matrix{Float64}
+# netInstabilitiesUb::Vector{Float64}
+# netInstabilitiesLb::Vector{Float64}
+# instabilitiesUb::Matrix{Float64}
+# instabilitiesLb::Matrix{Float64}
+# netInstabilities::Vector{Float64}
+# geneInstabilities::Matrix{Float64}
+# lambdaRange::Vector{Float64}
+# lambdaRangeGene::Vector{Vector{Float64}}
+# stabilityMat::Array{Float64}
+# priorMatProcessed::Matrix{Float64}
+# betas::Array{Float64,3}
+# function GrnData()
+# return new(
+# Matrix{Float64}(undef, 0, 0), # predictorMat
+# Matrix{Float64}(undef, 0, 0), # penaltyMat
+# [], # allPredictors
+# Matrix{Int64}(undef, 0, 0), # subsamps
+# Matrix{Int64}(undef, 0, 0), # responseMat
+# 0.0, # maxLambdasNet
+# 0.0, # minLambdasNet
+# Matrix{Int64}(undef, 0, 0), # minLambdas
+# Matrix{Int64}(undef, 0, 0), # maxLambdas
+# [], # netInstabilitiesUb
+# [], # netInstabilitiesLb
+# Matrix{Int64}(undef, 0, 0), # instabilitiesUb
+# Matrix{Int64}(undef, 0, 0), # instabilitiesLb
+# [], # netInstabilities
+# Matrix{Int64}(undef, 0, 0), # geneInstabilities
+# [], # lambdaRange
+# Vector{Vector{Float64}}(undef, 0), # lambdaRangesGene
+# Matrix{Int64}(undef, 0, 0), # stabilityMat
+# Matrix{Float64}(undef, 0, 0), # priorMatProcessed
+# Array{Float64,3}(undef, 0, 0, 0) # betas
+# )
+# end
+# end
+
+function preparePredictorMat!(grnData::GrnData, expressionData::GeneExpressionData, priorData::PriorTFAData; tfaOpt::String = "")
+ if tfaOpt != ""
+ println("noTfa option")
+ pRegs = priorData.pRegsNoTfa;
+ pTargs = priorData.pTargsNoTfa;
+ priorMatrix = priorData.priorMatrixNoTfa;
+ else
+ pRegs = priorData.pRegs;
+ pTargs = priorData.pTargs;
+ priorMatrix = priorData.priorMatrix;
+ end
+ # TFs in potRegs that have expression data but aren't in the prior; get the index for these TFs
+ # in the TF expression matrix
+ uniNoPriorRegs = setdiff(expressionData.potRegsmRNA, pRegs)
+ uniNoPriorRegInds = findall(in(uniNoPriorRegs), expressionData.potRegsmRNA)
+
+ # allPredictors includes both the pRegs and the potRegs that weren't in the prior
+ allPredictors = vcat(pRegs, uniNoPriorRegs)
+ totPreds = length(allPredictors)
+
+ # Create a new prior matrix whose target genes follow the same order as targGenes. When using
+ # TFA, TFs absent from the prior but with expression data are appended as mRNA-based predictors.
+ targGeneInds = findall(in(pTargs), expressionData.targGenes)
+ priorGeneInds = findall(in(expressionData.targGenes), pTargs)
+ totTargGenes = length(expressionData.targGenes)
+ totPRegs = length(pRegs)
+ priorMat = zeros(totTargGenes,totPreds)
+ priorMat[targGeneInds,1:totPRegs] = priorMatrix[priorGeneInds,:]
+
+
+ # predictorMat will be TFA (when available) and TFmRNA when TFA not available
+ if tfaOpt == ""
+ # TFA: stack TFA estimates on top of mRNA for TFs not in prior
+ predictorMat = [priorData.medTfas; expressionData.potRegMatmRNA[uniNoPriorRegInds,:]]
+ else
+ # If not using TFA: use mRNA for all predictors
+ currPredMat = zeros(totPreds, length(expressionData.cellLabels))
+ for prend = 1:totPreds
+ prendInd = findall(x -> x == allPredictors[prend], expressionData.potRegsmRNA)
+ currPredMat[prend, :] = expressionData.potRegMatmRNA[prendInd, :]
+ end
+ predictorMat = currPredMat
+ println("TF mRNA used.")
+ end
+ grnData.predictorMat = predictorMat
+ grnData.allPredictors = allPredictors
+ grnData.responseMat = expressionData.targGeneMat
+ grnData.priorMatProcessed = priorMat
+end
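The predictor assembly above, in miniature (hypothetical TF names and toy matrices): TFs in the prior keep their TFA rows, and expression-only TFs are appended with their mRNA rows via the same `setdiff`/`findall(in(...))` index mapping.

```julia
pRegs       = ["TF1", "TF2"]               # TFs with prior edges (toy)
potRegsmRNA = ["TF1", "TF2", "TF3"]        # TFs with expression data (toy)

uniNoPriorRegs    = setdiff(potRegsmRNA, pRegs)               # ["TF3"]
uniNoPriorRegInds = findall(in(uniNoPriorRegs), potRegsmRNA)  # [3]
allPredictors = vcat(pRegs, uniNoPriorRegs)

medTfas = [1.0 2.0; 3.0 4.0]               # 2 TFs × 2 conditions (toy TFA)
mRNA    = [0.1 0.2; 0.3 0.4; 0.5 0.6]      # 3 TFs × 2 conditions (toy mRNA)
predictorMat = [medTfas; mRNA[uniNoPriorRegInds, :]]  # TFA rows + mRNA row for TF3
```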
+
+
+function preparePenaltyMatrix!(expressionData::GeneExpressionData, grnData::GrnData;
+ priorFilePenalties::Vector{String} = String[],
+ lambdaBias::Vector{Float64} = [0.5],
+ tfaOpt::String = "")
+ #1. Update Penalty Matrix
+ # Create a dictionary to store the minimum lambda for each interaction
+ minLambdaDict = Dict{Tuple{String,String}, Float64}()
+ penaltyMatrix = ones(length(expressionData.targGenes),length(grnData.allPredictors))
+
+ # Build index maps for fast lookup once; they do not depend on the prior file
+ geneToIdx = Dict(gene => i for (i, gene) in enumerate(expressionData.targGenes))
+ tfToIdx = Dict(tf => i for (i, tf) in enumerate(grnData.allPredictors))
+
+ # Iterate through each prior file and its associated lambda
+ for (filePath, lambda) in zip(priorFilePenalties, lambdaBias)
+ # Read the prior file (gene rows, TF columns; first row holds TF names)
+ priorData = readdlm(filePath)
+
+ # Extract gene names from the first column (skipping the header row)
+ priorGenes = priorData[2:end, 1]
+ # Get TF names, filtering out any empty or missing entries
+ priorTFs = filter(tf -> !ismissing(tf) && !isempty(string(tf)), priorData[1, :])
+
+ # Record the minimum lambda seen for each nonzero prior interaction
+ for (i, gene) in enumerate(priorGenes)
+ for (j, tf) in enumerate(priorTFs)
+ if priorData[i+1, j+1] != 0 && haskey(geneToIdx, gene) && haskey(tfToIdx, tf)
+ interaction = (gene, tf)
+ if !haskey(minLambdaDict, interaction) || lambda < minLambdaDict[interaction]
+ minLambdaDict[interaction] = lambda
+ end
+ end
+ end
+ end
+ end
+
+ # Apply the penalties using the minimum lambda values
+ for ((gene, tf), minLambda) in minLambdaDict
+ penaltyMatrix[geneToIdx[gene], tfToIdx[tf]] = minLambda
+ end
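+
+ # Worked example (hypothetical values): if prior file A contains edge
+ # ("geneX", "tfY") with lambdaBias 0.5 and prior file B contains the same
+ # edge with lambdaBias 0.25, the dictionary keeps the minimum, so
+ # minLambdaDict[("geneX", "tfY")] == 0.25 and that edge's penalty is 0.25.
+ # In other words, the most confident (least penalized) prior wins.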
+
+ # 2. Set infinite penalties for edges that cannot be estimated
+ totPreds = length(grnData.allPredictors)
+ if tfaOpt != ""
+ ## set lambda penalty to infinity for positive feedback edges where TF
+ # mRNA levels serves both as gene expression and TFA estimate
+ for pr = 1:totPreds # Changed this from length(expressionData.potRegs) to length(grnData.allPredictors)
+ targInd = findall(x -> x==grnData.allPredictors[pr], expressionData.targGenes)
+ if length(targInd) > 0 # set lambda penalty to infinity, avoid predicting a TF's mRNA based on its own mRNA level
+ penaltyMatrix[targInd,pr] .= Inf # i.e., target gene is its own predictor
+ end
+ end
+ else # have to set prior inds to zero for TFs in TFA that don't have prior info
+ for pr = 1:totPreds # Changed this from expressionData.potRegs to grnData.allPredictors
+ if sum(abs.(grnData.priorMatProcessed[:,pr])) == 0 # we have no target edges to estimate TF's TFA
+ targInd = findall(x -> x==grnData.allPredictors[pr], expressionData.targGenes)
+ if length(targInd) > 0 # And TF is in the predictor set
+ penaltyMatrix[targInd,pr] .= Inf
+ end
+ end
+ end
+ end
+
+ grnData.penaltyMat = penaltyMatrix
+end
+
+
+function constructSubsamples(expressionData::GeneExpressionData, grnData::GrnData;
+ leaveOutSampleList::Union{String, Nothing}=nothing, totSS = 200, subsampleFrac = 0.63)
+ totSamps = length(expressionData.cellLabels)
+
+ if leaveOutSampleList !== nothing && !isempty(leaveOutSampleList)
+ println("Leave-out set detected: ", leaveOutSampleList)
+ # get leave-out set of samples
+ fin = open(leaveOutSampleList)
+ C = readdlm(fin,skipstart=0)
+ C = convert(Matrix{String}, C)
+ close(fin)
+ testInds = Tuple.(findall(in(C), expressionData.cellLabels))
+ testInds = first.(testInds)
+ trainInds = setdiff(1:totSamps,testInds)
+ else
+ println("Full gene expression matrix used.")
+ trainInds = 1:totSamps # all training samples used
+ testInds = []
+ end
+
+ subsampleSize = convert(Int64, floor(subsampleFrac*length(trainInds)))
+ # get subsample indices, drawn from the training samples only so that
+ # leave-out samples never enter a subsample
+ subsamps = zeros(Int, totSS, subsampleSize)
+ for ss = 1:totSS
+ subsamps[ss,:] = trainInds[randperm(length(trainInds))[1:subsampleSize]]
+ end
+ grnData.subsamps = subsamps
+end
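+
+# Example (hypothetical sizes): with 100 training samples and the defaults
+# (totSS = 200, subsampleFrac = 0.63), grnData.subsamps is a 200 x 63 matrix
+# whose rows each hold floor(0.63*100) = 63 training-sample indices.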
+
+function bstarsWarmStart(expressionData::GeneExpressionData, tfaData::PriorTFAData, grnData::GrnData; minLambda = 0.01, maxLambda = 0.5, totLambdasBstars = 20, totSS = 5, targetInstability = 0.05, zTarget = false)
+ # Determine the lambda levels to test
+ lambdaRange = collect(range(minLambda, stop = maxLambda, length = totLambdasBstars))
+ #lamLog10step = 1/totLambdasBstars
+ #logLamRange = log10(minLambda):lamLog10step:log10(maxLambda)
+ #lambdaRange = 10 .^ (logLamRange)
+ lambdaRange = reverse(lambdaRange)
+ totResponses = size(grnData.responseMat)[1]
+
+ instabilitiesLb = zeros(totResponses,totLambdasBstars)
+ instabilitiesUb = zeros(totResponses,totLambdasBstars)
+ minLambdas = zeros(totResponses,1)
+ maxLambdas = zeros(totResponses,1)
+
+ responsePredInds = Vector{Vector{Int}}(undef,0)
+ for res = 1:totResponses
+ currWeights = grnData.penaltyMat[res,:]
+ push!(responsePredInds,findall(x -> x!=Inf, currWeights))
+ end
+
+ netInstabilitiesLb = zeros(totResponses, totLambdasBstars)
+ netInstabilitiesUb = zeros(totResponses, totLambdasBstars)
+ totEdges = zeros(totResponses) # denominator for network Instabilities
+
+ theta2save = Vector{Matrix{Float64}}(undef, totResponses) # Debugging #
+
+ Threads.@threads for res in ProgressBar(1:totResponses) # can be a parfor loop
+ predInds = responsePredInds[res]
+ currPredNum = length(predInds)
+ penaltyFactor = grnData.penaltyMat[res, predInds]
+ totEdges[res] = currPredNum
+ ssVals = zeros(totLambdasBstars,currPredNum)
+ for ss = 1:totSS
+ subsamp = grnData.subsamps[ss,:]
+ dt = fit(ZScoreTransform, grnData.predictorMat[predInds, subsamp], dims=2)
+ currPreds = transpose(StatsBase.transform(dt, grnData.predictorMat[predInds, subsamp]))
+ if zTarget
+ dt = fit(ZScoreTransform, grnData.responseMat[res, subsamp], dims=1)
+ currResponses = StatsBase.transform(dt, grnData.responseMat[res, subsamp])
+ else
+ currResponses = grnData.responseMat[res, subsamp]
+ end
+ lsoln = glmnet(currPreds, currResponses, penalty_factor = penaltyFactor, lambda = lambdaRange, alpha = 1.0)
+ currBetas = lsoln.betas # predictors X lambdas, in the order of lambdaRange
+ ssVals .+= abs.(sign.(currBetas))' # count subsamples in which each edge is nonzero
+ end
+ theta2 = (1/totSS)*ssVals # empirical edge probability
+ theta2save[res] = theta2 #For Debugging #
+ instabilitiesLb[res,:] = 2 * (1/currPredNum) .* sum(theta2 .* (1 .- theta2), dims=2) # bStARS lower bound
+ netInstabilitiesLb[res,:] = currPredNum*(instabilitiesLb[res,:])
+ theta2mean = sum(theta2,dims=2)./currPredNum
+ instabilitiesUb[res,:] = 2 * theta2mean .* (1 .- theta2mean) # bStARS upper bound
+ netInstabilitiesUb[res,:] = currPredNum*instabilitiesUb[res,:]
+ end
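+
+ # Numeric sketch of the quantities above (hypothetical counts): if an edge is
+ # selected in 3 of totSS = 5 subsamples, theta2 = 3/5 = 0.6 and its edge-wise
+ # instability is 2*theta2*(1 - theta2) = 2*0.6*0.4 = 0.48. The lower bound
+ # averages theta2 .* (1 .- theta2) over predictors first, while the upper
+ # bound applies 2*p*(1 - p) to the mean selection probability; because
+ # p*(1 - p) is concave, Lb <= Ub for every lambda (Jensen's inequality).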
+
+ totEdges = sum(totEdges)
+ netInstabilitiesLb = sum(netInstabilitiesLb, dims=1)[:]
+ netInstabilitiesUb = sum(netInstabilitiesUb, dims=1)[:]
+
+ for res = 1:totResponses
+ # take the supremum, find max Lambda, and set all smaller lambdas equal to that value
+ maxLb = findmax(instabilitiesLb[res,:])
+ maxLbInd = findall(x -> x == maxLb[1], instabilitiesLb[res,:])
+ maxLb = maxLb[1]
+ instabilitiesLb[res,maxLbInd[end]:end] .= maxLb
+ maxUb = findmax(instabilitiesUb[res,:])
+ maxUbInd = findall(x -> x == maxUb[1], instabilitiesUb[res,:])
+ maxUb = maxUb[1]
+ instabilitiesUb[res,maxUbInd[end]:end] .= maxUb
+ # find the minimum lambda for the gene, based on maximum for upper bound
+ # we are less interested in high instability lambdas, so okay to use
+ # upper bound
+ xx = findmin(abs.(instabilitiesLb[res,:] .- targetInstability))
+ xx = xx[2]
+ minLambdas[res] = (lambdaRange)[xx[end]] # to the right
+ # find the lambda nearest the min instability worth considering, use
+ # upper bound as that will be sure to find an lambda >= target instability lambda
+ xx = findmin(abs.(instabilitiesUb[res,:] .- targetInstability))
+ xx = findall(x -> x == xx[1], abs.(instabilitiesUb[res,:] .- targetInstability))
+ maxLambdas[res] = (lambdaRange)[xx[end]] # to the right
+ # note for typical bStARS, where you know what instability cutoff you
+ # want you'd use the upperbound to find the min lambda and the lb to
+ # find the max lambda
+ end
+
+ netInstabilitiesUb = netInstabilitiesUb ./ totEdges
+ netInstabilitiesLb = netInstabilitiesLb ./ totEdges
+ maxLb = findmax(netInstabilitiesLb)
+ maxLbInd = findall(x -> x == maxLb[1], netInstabilitiesLb)
+ netInstabilitiesLb[maxLbInd[end]:end] .= maxLb[1] # take supremum for lambdas smaller than instability max
+ maxUb = findmax(netInstabilitiesUb)
+ maxUbInd = (findall(x -> x == maxUb[1], netInstabilitiesUb))
+ netInstabilitiesUb[maxUbInd[end]:end] .= maxUb[1]
+ xx = findmin(abs.(netInstabilitiesLb .- targetInstability))
+ maxInstInd = findall(x -> x == xx[1], abs.(netInstabilitiesLb .- targetInstability))
+ minLambdaNet = (lambdaRange)[maxInstInd[end]]
+ xx = findmin(abs.(netInstabilitiesUb .- targetInstability))
+ minInstInd = findall(x -> x == xx[1], abs.(netInstabilitiesUb .- targetInstability))
+ maxLambdaNet = (lambdaRange)[minInstInd[end]]
+
+ grnData.minLambdaNet = minLambdaNet
+ grnData.maxLambdaNet = maxLambdaNet
+ grnData.maxLambdas = maxLambdas
+ grnData.minLambdas = minLambdas
+ grnData.netInstabilitiesLb = netInstabilitiesLb
+ grnData.netInstabilitiesUb = netInstabilitiesUb
+ grnData.instabilitiesLb = instabilitiesLb
+ grnData.instabilitiesUb = instabilitiesUb
+end
+
+function bstartsEstimateInstability(grnData::GrnData; totLambdas = 10, instabilityLevel = "Gene", zTarget = false, outputDir::Union{String, Nothing}=nothing)
+
+ totResponses,totSamps = size(grnData.responseMat) # totResponses is same as length(grnData["targGenes"])
+ totPreds = size(grnData.predictorMat,1)
+
+ # Gene Level Instabilities
+ lambdaRangeGene = Vector{Vector{Float64}}(undef, totResponses)
+ for res in 1:totResponses
+ λmin = grnData.minLambdas[res]
+ λmax = grnData.maxLambdas[res]
+ lambdaRangeGene[res] = reverse(collect(range(λmin, stop=λmax, length=totLambdas)))
+ end
+ grnData.lambdaRangeGene = lambdaRangeGene
+
+ # Network Level Instability
+ minLambda = grnData.minLambdaNet
+ maxLambda = grnData.maxLambdaNet
+
+ lambdaRange = collect(range(minLambda, stop = maxLambda, length = totLambdas))
+ lambdaRange = reverse(lambdaRange)
+ grnData.lambdaRange = lambdaRange
+
+ geneInstabilities = zeros(totResponses,totLambdas)
+ netInstabilities = zeros(totLambdas,1)
+ # store number of subsamples for which an edge was nonzero, given that some
+ # prior weights can be set to infinity, track to make sure these edges are
+ # not counted
+ ssMatrix = Inf*ones(totLambdas,totResponses,totPreds)
+ subsamps = grnData.subsamps
+ totSS = size(subsamps)[1]
+
+ # get (finite) predictor indices for each response
+ responsePredInds = Vector{Vector{Int}}(undef,0)
+ for res = 1:totResponses
+ currWeights = grnData.penaltyMat[res,:]
+ push!(responsePredInds,findall(x -> x!=Inf, currWeights))
+ end
+
+ totEdges = zeros(totResponses)
+ betas = Array{Float64,3}(undef, totResponses, totPreds, totLambdas)
+ Threads.@threads for res in ProgressBar(1:totResponses)
+ lambdaRange = instabilityLevel == "Network" ?
+ grnData.lambdaRange : # single shared vector
+ grnData.lambdaRangeGene[res] # per-gene vector
+ predInds = responsePredInds[res]
+ currPredNum = length(predInds)
+ totEdges[res] = currPredNum
+ penaltyFactor = grnData.penaltyMat[res,predInds]
+ ssVals = zeros(totLambdas,currPredNum)
+ for ss = 1:totSS
+ subsamp = subsamps[ss,:]
+ dt = fit(ZScoreTransform, grnData.predictorMat[predInds, subsamp], dims=2)
+ currPreds = transpose(StatsBase.transform(dt, grnData.predictorMat[predInds, subsamp]))
+ if zTarget
+ dt = fit(ZScoreTransform, grnData.responseMat[res, subsamp], dims=1)
+ currResponses = StatsBase.transform(dt, grnData.responseMat[res, subsamp])
+ else
+ currResponses = grnData.responseMat[res, subsamp]
+ end
+ lsoln = glmnet(currPreds, currResponses, penalty_factor = penaltyFactor, lambda = lambdaRange, alpha = 1.0)
+ currBetas = lsoln.betas
+ betas[res,predInds, :] = currBetas
+ ssVals .+= abs.(sign.(currBetas))'
+ end
+ ssMatrix[:,res,predInds] = ssVals
+ end
+
+ for res = 1:totResponses
+ currWeights = grnData.penaltyMat[res,:]
+ predInds = responsePredInds[res]
+ currPredNum = length(predInds)
+ ssVals = zeros(totLambdas,currPredNum)
+ ssVals[:,:] = ssMatrix[:,res,predInds]
+ theta2 = (1/totSS)*ssVals # empirical edge probability, lambdas X currPreds
+ instabilitiesPerEdge = 2*(theta2 .* (1 .- theta2))
+ aveInstabilities = mean(instabilitiesPerEdge,dims=2) # lambdas X 1
+ maxUb = (findmax(aveInstabilities))
+ instabSUP = aveInstabilities
+ maxUbInd = first(Tuple(maxUb[2]))
+ instabSUP[maxUbInd:end,1] .= maxUb[1]
+ geneInstabilities[res,:] = instabSUP
+ end
+
+ ## calculate instabilities network-wise
+ currSS = zeros(totResponses,totPreds)
+ for lind = 1:totLambdas # start at highest lambda (lowest instability) and work down
+ currSS[:,:] = ssMatrix[lind,:,:]
+ theta2 = (1/totSS)*currSS # empirical edge probability, responses X currPreds
+ instabilitiesPerEdge = 2*(theta2 .* (1 .- theta2))
+ instabVec = instabilitiesPerEdge[:]
+ validEdges = findall(isfinite.(currSS[:])) # limit to finite edges
+ # running supremum: instability is forced to be non-decreasing as lambda decreases
+ instabMax = lind == 1 ? 0.0 : maximum(netInstabilities[1:lind-1])
+ netInstabilities[lind] = max(mean(instabVec[validEdges]), instabMax)
+ end
+ grnData.netInstabilities = vec(netInstabilities)
+ grnData.geneInstabilities = geneInstabilities
+ grnData.stabilityMat = ssMatrix
+ grnData.betas = betas
+
+ if outputDir !== nothing && outputDir !== ""
+ outputFile = joinpath(outputDir, "instabOutMat.jld")
+ save_object(outputFile, grnData)
+ end
+end
diff --git a/src/grn/RefineTFA.jl b/src/grn/RefineTFA.jl
new file mode 100755
index 0000000..611a935
--- /dev/null
+++ b/src/grn/RefineTFA.jl
@@ -0,0 +1,94 @@
+"""
+RefineTFA
+
+Main function:
+- `refineTFA`: Estimates transcription factor activities (TFA) from gene expression
+  data, a prior network, and merged-TF metadata, optionally writing the median TFA
+  matrix to disk.
+
+Dependencies:
+- `Data.GeneExpressionData` and `Data.PriorTFAData`
+- `Prior.mergeDegenerateTFs`
+"""
+
+function refineTFA(data::GeneExpressionData, mergedTFsData::mergedTFsResult;
+ priorFile::String = "",
+ tfaGeneFile::String = "",
+ edgeSS::Int = 0,
+ minTargets::Int = 3,
+ zTarget::Bool = true,
+ geneExprFile::String = "",
+ targFile::String = "",
+ regFile::String = "",
+ outputDir::Union{String, Nothing} = nothing)
+ if !isnothing(outputDir)
+ mkpath(outputDir)
+ end
+
+ # Ensure required expression data are available and non-empty
+ requiredFields = [
+ :tfaGenes,
+ :tfaGeneMat,
+ :potRegs,
+ :potRegsmRNA,
+ :potRegMatmRNA
+ ]
+
+ missingFields = [field for field in requiredFields if isempty(getfield(data, field))]
+
+ if !isempty(missingFields)
+ println("Missing or empty fields in data: ", missingFields)
+ println("Generating required data by loading from input files...")
+
+ requiredFiles = [
+ ("Gene Expression File", geneExprFile),
+ ("Target Gene File", targFile),
+ ("Potential Regulators File", regFile),
+ ]
+
+ missingFiles = [name for (name, path) in requiredFiles if isempty(path) || !isfile(path)]
+ if !isempty(missingFiles)
+ error("Cannot generate data. Missing input files: ", missingFiles)
+ end
+
+ data = GeneExpressionData()
+ loadExpressionData!(data, geneExprFile)
+ loadAndFilterTargetGenes!(data, targFile; epsilon=0.01)
+ loadPotentialRegulators!(data, regFile)
+ processTFAGenes!(data, tfaGeneFile; outputDir = outputDir)
+ end
+
+
+ # 2. Integrate prior information for TFA estimation
+ tfaData = PriorTFAData()
+ processPriorFile!(tfaData, data, priorFile; mergedTFsData, minTargets = minTargets);
+ calculateTFA!(tfaData, data; edgeSS = edgeSS, zTarget = zTarget, outputDir = outputDir);
+
+ # Save the median TFA as a text file if outputDir is specified
+ if !isnothing(outputDir)
+ namedMedTFA = NamedArray(tfaData.medTfas)
+ setnames!(namedMedTFA, tfaData.pRegs, 1)
+ setnames!(namedMedTFA, data.cellLabels, 2)
+
+ outputfile = joinpath(outputDir, "TFA.txt")
+ open(outputfile, "w") do io
+ writedlm(io, permutedims(data.cellLabels))
+ writedlm(io, namedMedTFA)
+ end
+ end
+
+ return tfaData
+end
diff --git a/src/grn/UtilsGRN.jl b/src/grn/UtilsGRN.jl
new file mode 100755
index 0000000..548ef29
--- /dev/null
+++ b/src/grn/UtilsGRN.jl
@@ -0,0 +1,54 @@
+"""
+    firstNByGroup(vect::AbstractVector, N::Integer)
+
+Return the indices of the first `N` occurrences of each unique element in `vect`.
+
+This is useful when you want to subset another array based on limited occurrences per group.
+
+# Arguments
+- `vect`: A vector of values to group by (e.g., transcription factors).
+- `N`: The maximum number of elements to select per unique group.
+
+# Returns
+- A vector of indices corresponding to the first `N` entries per group.
+"""
+function firstNByGroup(vect::AbstractVector, N::Integer)
+ seen = Dict{eltype(vect), Int}()
+ idxs = Int[]
+ for (i, v) in enumerate(vect)
+ seen[v] = get(seen, v, 0) + 1
+ if seen[v] <= N
+ push!(idxs, i)
+ end
+ end
+ return idxs
+end
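+
+# Example: firstNByGroup(["A", "A", "A", "B", "A"], 2) returns [1, 2, 4],
+# i.e. the first two "A" positions and the only "B" position.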
+
+
+
diff --git a/src/metrics/CalcPR.jl b/src/metrics/CalcPR.jl
new file mode 100755
index 0000000..5a904db
--- /dev/null
+++ b/src/metrics/CalcPR.jl
@@ -0,0 +1,658 @@
+"""
+ computeMacroMetrics(gsPrecisionsByTf, gsRecallsByTf, gsFprsByTf, gsAuprsByTf, gsArocsByTf;
+ target_points=1000, min_step=1e-4, step_method=:min_gap)
+
+Compute macro-averaged PR and ROC curves across all TFs using interpolation.
+
+Per-TF curves are interpolated onto a common grid and averaged. Only TFs with
+non-empty, non-zero recall arrays contribute to the average.
+
+# Arguments
+- `gsPrecisionsByTf` : Vector of per-TF precision arrays.
+- `gsRecallsByTf` : Vector of per-TF recall arrays.
+- `gsFprsByTf` : Vector of per-TF false positive rate arrays.
+- `gsAuprsByTf` : Vector of per-TF AUPR values.
+- `gsArocsByTf` : Vector of per-TF AUROC values.
+
+# Keyword Arguments
+- `target_points::Int=1000` : Number of interpolation points if `step_method=:target_points`.
+- `min_step::Float64=1e-4` : Minimum allowable step size for interpolation.
+- `step_method::Symbol=:min_gap` : Step selection method — `:min_gap` or `:target_points`.
+
+# Returns
+An `OrderedDict` with two nested `OrderedDict`s:
+- `:macroPR` → `:auprInterpolated`, `:precisions`, `:recalls`
+- `:macroROC` → `:aurocInterpolated`, `:fprs`, `:tprs`
+
+# Example
+```julia
+results = computeMacroMetrics(precsByTf, recsByTf, fprsByTf, auprsByTf, aurocsByTf)
+results[:macroPR][:auprInterpolated] # macro AUPR
+results[:macroPR][:precisions] # macro precision curve
+results[:macroROC][:tprs] # macro TPR curve
+```
+"""
+function computeMacroMetrics(gsPrecisionsByTf, gsRecallsByTf, gsFprsByTf, gsAuprsByTf, gsArocsByTf;
+ target_points=1000,
+ min_step=1e-4,
+ step_method=:min_gap)
+
+ # --- Scalar Macro Metrics ---
+ macroAUPR_scalar = mean(filter(isfinite, gsAuprsByTf))
+ macroAUROC_scalar = mean(filter(isfinite, gsArocsByTf))
+
+
+ # --- Interpolated Macro Curves ---
+
+ # Determine dynamic max recall/FPR
+ allRecalls = reduce(vcat, filter(!isempty, gsRecallsByTf))
+ allRecalls = filter(x -> isfinite(x), allRecalls)
+ allFprs = reduce(vcat, filter(!isempty, gsFprsByTf))
+ allFprs = filter(x -> isfinite(x), allFprs)
+
+ maxRecall = isempty(allRecalls) ? 0.0 : maximum(allRecalls)
+ maxFpr = isempty(allFprs) ? 0.0 : maximum(allFprs)
+
+ recStep = adaptiveStep(gsRecallsByTf; target_points=target_points, min_step=min_step, method=step_method)
+ fprStep = adaptiveStep(gsFprsByTf; target_points=target_points, min_step=min_step, method=step_method)
+ recInterpPts = 0.0:recStep:maxRecall # Interpolation points
+ fprInterpPts = 0.0:fprStep:maxFpr
+
+ macroPrecisions = zeros(length(recInterpPts))
+ macroTprs = zeros(length(fprInterpPts))
+ nTFs = length(gsPrecisionsByTf)
+ validTFs = [i for i in 1:nTFs if !isempty(gsRecallsByTf[i]) && maximum(gsRecallsByTf[i]) > 0]
+ for i in validTFs
+ # PR Curve
+ interpPrec = dedupNInterpolate(gsRecallsByTf[i], gsPrecisionsByTf[i], recInterpPts)
+ macroPrecisions .+= interpPrec
+ # ROC Curve # This is needed for ROC curves. Here Tprs != Recall but interpolated Recall on Fprs
+ interpTpr = dedupNInterpolate(gsFprsByTf[i], gsRecallsByTf[i], fprInterpPts)
+ macroTprs .+= interpTpr
+ end
+
+ nValid = max(length(validTFs), 1) # avoid division by zero when no TF has nonzero recall
+ macroPrecisions ./= nValid
+ macroTprs ./= nValid
+ macroAUPR_interpolated = sum(0.5 .* (macroPrecisions[2:end] + macroPrecisions[1:end-1]) .* diff(recInterpPts))
+ macroAUROC_interpolated = sum(0.5 .* (macroTprs[2:end] + macroTprs[1:end-1]) .* diff(fprInterpPts))
+
+ macroResults = OrderedDict(
+ :macroPR => OrderedDict(
+ # :auprScalar => macroAUPR_scalar,
+ :auprInterpolated => macroAUPR_interpolated,
+ :precisions => macroPrecisions,
+ :recalls => collect(recInterpPts)
+ ),
+ :macroROC => OrderedDict(
+ # :aurocScalar => macroAUROC_scalar,
+ :aurocInterpolated => macroAUROC_interpolated,
+ :fprs => collect(fprInterpPts),
+ :tprs => macroTprs
+ )
+)
+ return macroResults
+end
+
+"""
+ evaluatePerTF(regs, targs, rankings, infEdges, gsEdgesByTf, uniGsRegs, gsRandPRbyTf;
+ saveDir=nothing, breakTies=true, target_points=1000, min_step=1e-4,
+ step_method=:min_gap, xLimitRecall=1.0)
+
+Compute per-TF precision-recall and ROC metrics for an inferred gene regulatory network (GRN)
+against a gold standard, and optionally plot per-TF PR/ROC curves.
+
+# Arguments
+- `regs::Vector{<:AbstractString}` : Regulators for each inferred edge.
+- `targs::Vector{<:AbstractString}` : Target genes for each inferred edge.
+- `rankings::Vector{Float64}` : Confidence scores for inferred edges.
+- `infEdges::Vector{String}` : Inferred edges as `"TF,Target"` strings.
+- `gsEdgesByTf::Vector{Array{String}}` : Gold standard edges grouped by TF.
+- `uniGsRegs::Vector{<:AbstractString}` : Unique TFs in the gold standard.
+- `gsRandPRbyTf::Vector{Float64}` : Random baseline AUPR per TF.
+
+# Keyword Arguments
+- `saveDir::Union{String,Nothing}=nothing` : Directory to save per-TF plots. If `nothing`, plots are skipped.
+- `breakTies::Bool=true` : Average indicators over tied rankings to smooth PR curves.
+- `target_points::Int=1000` : Interpolation points for macro curves.
+- `min_step::Float64=1e-4` : Minimum step size for interpolation.
+- `step_method::Symbol=:min_gap` : Step selection method for macro metrics.
+- `xLimitRecall::Float64=1.0` : Maximum recall shown on the x-axis of diagnostic per-TF PR plots.
+
+# Returns
+An `OrderedDict` with:
+- `:perTF` → `OrderedDict` with `:gsRegs`, `:tfEdges`, `:tfEdgesVec`, `:precisions`,
+ `:recalls`, `:fprs`, `:randPR`, `:auprs`, `:arocs`
+- `:macroPR` → `OrderedDict` with `:auprInterpolated`, `:precisions`, `:recalls`
+- `:macroROC` → `OrderedDict` with `:aurocInterpolated`, `:fprs`, `:tprs`
+
+# Example
+```julia
+results = evaluatePerTF(regs, targs, rankings, infEdges,
+                        gsEdgesByTf, uniGsRegs, gsRandPRbyTf)
+results[:perTF][:auprs] # per-TF AUPR values
+results[:macroPR][:auprInterpolated] # macro AUPR
+```
+"""
+function evaluatePerTF(
+ regs::Vector{<:AbstractString}, targs::Vector{<:AbstractString}, rankings::Vector{Float64}, infEdges::Vector{String},
+ gsEdgesByTf::Vector{Array{String}}, uniGsRegs::Vector{<:AbstractString}, gsRandPRbyTf::Vector{Float64};
+ saveDir::Union{AbstractString, Nothing} = nothing, breakTies::Bool=true, target_points=1000, min_step=1e-4, step_method=:min_gap, xLimitRecall::Float64 = 1.0)
+
+
+ if saveDir !== nothing && !isempty(saveDir)
+ saveDir = joinpath(saveDir, "perTF")
+ mkpath(saveDir)
+ end
+
+ totGsRegs = length(uniGsRegs)
+ totTargGenes = maximum(length.(gsEdgesByTf)) # largest per-TF target set, used as a proxy for the total number of possible targets
+
+ println("---- Computing Per-TF Metrics")
+ gsAuprsByTf = zeros(totGsRegs,1)
+ gsArocsByTf = zeros(totGsRegs,1)
+ gsPrecisionsByTf = Array{Float64}[]
+ gsRecallsByTf = Array{Float64}[]
+ gsFprsByTf = Array{Float64}[]
+
+ allTfEdgesVec = Array{Float64}[] # one vector per TF
+ allTfEdges = Vector{Vector{String}}() # raw edges
+ allTfRanks = Array{Float64}[] # one array per TF
+ for (iTF, tf) in enumerate(uniGsRegs)
+ # Filter inferred edges for this TF
+ inds = findall(x -> x == tf, regs)
+ tfEdges = infEdges[inds]
+ tfRanks = rankings[inds]
+ tfEdgesVec = [in(edge, gsEdgesByTf[iTF]) ? 1 : 0 for edge in tfEdges]
+ # Tie-break if needed
+ if breakTies
+ tfAbsRanks = abs.(tfRanks)
+ uniqueRanks = unique(tfAbsRanks)
+ meanIndicator = Dict{Float64, Float64}()
+ for r in uniqueRanks
+ idxs = findall(x -> x == r, tfAbsRanks)
+ meanIndicator[r] = mean(tfEdgesVec[idxs])
+ end
+ tfEdgesVec = [meanIndicator[r] for r in tfAbsRanks]
+ end
+ # Save for later
+ push!(allTfEdgesVec, tfEdgesVec)
+ push!(allTfEdges, tfEdges)
+ push!(allTfRanks, tfRanks) # store adjusted ranks for this TF
+
+ # Compute cumulative metrics
+ tfTP = 0.0
+ tfPrec = Float64[]
+ tfRec = Float64[]
+ tfFpr = Float64[]
+ totNegativesTF = totTargGenes - length(gsEdgesByTf[iTF])
+ for j in 1:length(tfEdgesVec)
+ tfTP += tfEdgesVec[j]
+ fp = j - tfTP
+ push!(tfPrec, tfTP / j)
+ push!(tfRec, tfTP / length(gsEdgesByTf[iTF]))
+ push!(tfFpr, fp / totNegativesTF)
+ end
+ # Prepend starting point
+ if !isempty(tfPrec)
+ tfRec = vcat(0.0, tfRec)
+ tfPrec = vcat(tfPrec[1], tfPrec)
+ tfFpr = vcat(0.0, tfFpr)
+ end
+
+ # -- Plot perTF metric
+ currRandPR = gsRandPRbyTf[iTF]
+ plotPRCurve(tfRec, tfPrec, currRandPR, saveDir; xLimitRecall=xLimitRecall, baseName = tf)
+ plotROCCurve(tfFpr, tfRec, saveDir; baseName = tf)
+
+ # Save per-TF metrics
+ push!(gsPrecisionsByTf, tfPrec)
+ push!(gsRecallsByTf, tfRec)
+ push!(gsFprsByTf, tfFpr)
+ # AUPR / AUROC per TF
+ heightsTF = (tfPrec[2:end] + tfPrec[1:end-1]) / 2
+ widthsTF = tfRec[2:end] - tfRec[1:end-1]
+ gsAuprsByTf[iTF] = sum(heightsTF .* widthsTF)
+
+ widthsRocTF = tfFpr[2:end] - tfFpr[1:end-1]
+ heightsRocTF = (tfRec[2:end] + tfRec[1:end-1]) / 2
+ gsArocsByTf[iTF] = sum(widthsRocTF .* heightsRocTF)
+ end
+
+ # Compute macro metrics
+
+ # A global result but requires perTF to compute
+ macroResults = computeMacroMetrics(gsPrecisionsByTf, gsRecallsByTf, gsFprsByTf, gsAuprsByTf, gsArocsByTf;
+ target_points=target_points, min_step=min_step, step_method=step_method)
+
+ perTFDict = OrderedDict(
+ :gsRegs => uniGsRegs,
+ :tfEdges => allTfEdges, # raw inferred edges per TF
+ :tfEdgesVec => allTfEdgesVec, # 0/1 indicator per TF
+ :precisions => gsPrecisionsByTf,
+ :recalls => gsRecallsByTf,
+ :fprs => gsFprsByTf,
+ :randPR => gsRandPRbyTf,
+ :auprs => gsAuprsByTf,
+ :arocs => gsArocsByTf
+ )
+
+ results = OrderedDict(
+ :perTF => perTFDict,
+ :macroPR => macroResults[:macroPR],
+ :macroROC => macroResults[:macroROC]
+ )
+ return results
+end
+
+"""
+ computePR(
+ gsFile::String, infTrnFile::String;
+ gsRegsFile::Union{String, Nothing} = nothing, targGeneFile::Union{String, Nothing} = nothing, # filtering
+ breakTies::Bool = true, partialAUPRlimit::Float64 = 0.1, doPerTF::Bool = true, # computation
+ xLimitRecall::Float64 = 1.0, saveDir::Union{String, Nothing} = nothing, # plotting
+ target_points::Int = 1000, min_step::Float64 = 1e-4, step_method::Symbol = :min_gap) # interpolation
+
+Compute precision-recall (PR) and ROC metrics for an inferred gene regulatory network (GRN)
+against a gold standard, including optional per-TF and macro-level metrics.
+Results are saved to a `.jld` file and PR/ROC curves are plotted automatically.
+
+# Arguments
+- `gsFile::String` : Gold standard TSV file with columns in order: TF (1), Target (2), Weight (3) — column names can be anything
+- `infTrnFile::String` : Inferred GRN TSV file with columns in order: TF (1), Target (2), Weight (3) — column names can be anything
+
+# Keyword Arguments
+- `gsRegsFile::Union{String,Nothing}=nothing` : File listing regulators to include. Default uses all GS regulators.
+- `targGeneFile::Union{String,Nothing}=nothing` : File listing target genes to include. Default uses all GS targets.
+- `breakTies::Bool=true` : Average indicators over tied rankings to smooth PR curves.
+- `partialAUPRlimit::Float64=0.1` : Maximum recall cutoff for computing partial AUPR.
+- `xLimitRecall::Float64=1.0` : Maximum recall shown on the x-axis of diagnostic PR plots.
+- `doPerTF::Bool=true` : Whether to compute per-TF and macro metrics.
+- `saveDir::Union{String,Nothing}=nothing` : Directory for saving plots and results. Defaults to a timestamped folder.
+- `target_points::Int=1000` : Interpolation points for macro curves.
+- `min_step::Float64=1e-4` : Minimum step size for interpolation.
+- `step_method::Symbol=:min_gap` : Step selection method for macro metrics.
+
+# Returns
+An `OrderedDict` with:
+- `:gsRegs`, `:gsTargs`, `:gsEdges` → Gold standard regulators, targets, and edges.
+- `:randPR` → Random baseline PR value.
+- `:infRegs`, `:infEdges` → Inferred GRN regulators and edges.
+- `:edgesVec` → Tie-adjusted or binary indicator vector.
+- `:precisions`, `:recalls`, `:fprs` → Overall PR/ROC curve arrays.
+- `:auprs` → Dict with `:full` AUPR and `:partial` AUPR (up to `partialAUPRlimit`).
+- `:arocs` → Overall AUROC.
+- `:f1scores` → F1-score array.
+- `:perTF` → Per-TF metrics dict (see `evaluatePerTF`). `nothing` if `doPerTF=false`.
+- `:macroPR` → Macro PR dict. `nothing` if `doPerTF=false`.
+- `:macroROC` → Macro ROC dict. `nothing` if `doPerTF=false`.
+- `:savedFile` → Path to saved `.jld` results file.
+
+# Example
+```julia
+results = computePR("gs.tsv", "inferredGRN.tsv";
+ breakTies=true, partialAUPRlimit=0.1, saveDir="plots")
+
+results[:auprs][:full] # overall AUPR
+results[:auprs][:partial][:value] # partial AUPR at recall limit
+results[:perTF][:auprs] # per-TF AUPRs
+results[:macroPR][:auprInterpolated] # macro AUPR
+```
+"""
+
+function computePR(
+ gsFile::String, infTrnFile::String;
+ gsRegsFile::Union{String, Nothing} = nothing, targGeneFile::Union{String, Nothing} = nothing, # filtering
+ breakTies::Bool = true, partialAUPRlimit::Float64 = 0.1, doPerTF::Bool = true, # computation
+ xLimitRecall::Float64 = 1.0, saveDir::Union{String, Nothing} = nothing, # plotting
+ target_points::Int = 1000, min_step::Float64 = 1e-4, step_method::Symbol = :min_gap) # interpolation
+
+ # Helper function for empty/fallback PR result
+ function emptyPRResult(;
+ gsRegs::Vector{String} = String[],
+ gsTargs::Vector{String} = String[],
+ gsEdges::Vector{String} = String[],
+ infRegs::Vector{String} = String[],
+ infEdges::Vector{String} = String[])
+ return OrderedDict(
+ :gsRegs => gsRegs,
+ :gsTargs => gsTargs,
+ :gsEdges => gsEdges,
+ :infRegs => infRegs,
+ :infEdges => infEdges,
+ :precisions => Float64[],
+ :recalls => Float64[],
+ :fprs => Float64[],
+ :auprs => Dict(
+ :full => 0.0,
+ :partial => Dict(
+ :value => 0.0,
+ :recallLimit => partialAUPRlimit # captured from outer scope
+ )
+ ),
+ :arocs => 0.0,
+ :f1scores => Float64[],
+ :perTF => nothing,
+ :macroPR => nothing,
+ :macroROC => nothing,
+ :savedFile => nothing
+ )
+ end
+ # ----- Part 1. Load and Validate Input Files
+ gsData = CSV.File(gsFile; delim="\t", header=true)
+ trnData = CSV.File(infTrnFile; delim="\t", header=true)
+
+ validateColumnCount(gsData, gsFile)
+ validateColumnCount(trnData, infTrnFile)
+
+ # Reload selecting only first 3 columns by position
+ gsData = CSV.File(gsFile; delim="\t", select=[1, 2, 3])
+ trnData = CSV.File(infTrnFile; delim="\t", select=[1, 2, 3])
+
+
+ # Initialize defaults BEFORE any conditionals
+ potTargGenes = String[]
+ potTargGenesSet = Set{String}()
+ gsPotRegs = String[]
+ gsPotRegsSet = Set{String}()
+
+ # ----- Part 2. Load Optional Filter Files (Potential Targets and Regulators)
+ if targGeneFile !== nothing && !isempty(targGeneFile)
+ potTargGenes = readlines(targGeneFile)
+ potTargGenesSet = Set(potTargGenes)
+ totTargGenes = length(unique(potTargGenes))
+ @info "Target gene file loaded: $(length(potTargGenes)) genes from $targGeneFile"
+ else
+ @info "No target gene file provided — will use gold standard targets after GS is loaded"
+ end
+
+ # Load regulator list if provided
+ if gsRegsFile !== nothing && !isempty(gsRegsFile)
+ gsPotRegs = readlines(gsRegsFile)
+ gsPotRegsSet = Set(gsPotRegs)
+ @info "Regulator file loaded: $(length(gsPotRegs)) regulators from $gsRegsFile"
+ else
+ @info "No regulator file provided — will use all regulators in gold standard"
+ end
+
+ # ----- Part 3. Filter and Process Gold Standard
+
+ # Filter gold standard by regulators and target genes if specified: limit to TF-gene interactions considered by the model
+ gsFilteredData = filter(row -> row[3] > 0, gsData)
+
+ # Filter by regulators if provided
+ if !isempty(gsPotRegsSet)
+ gsFilteredData = filter(row -> row[1] in gsPotRegsSet, gsFilteredData)
+ end
+
+ # Filter by targets if provided
+ if !isempty(potTargGenesSet)
+ gsFilteredData = filter(row -> row[2] in potTargGenesSet, gsFilteredData)
+ end
+
+ # Edge-case check: skip if no gold standard edges
+ if isempty(gsFilteredData)
+ @warn "No overlapping gold standard edges after filtering. Skipping PR/ROC evaluation."
+ return emptyPRResult()
+ end
+
+ # Define gold standard edges (each as "TF,Target")
+ gsRegs = collect(row[1] for row in gsFilteredData)
+ gsTargs = collect(row[2] for row in gsFilteredData)
+ totGsInts = size(gsFilteredData)[1]
+ uniGsRegs = unique(gsRegs) # unique regulators or TFs in GS
+ totGsRegs = length(uniGsRegs)
+ uniGsTargs = unique(gsTargs) # unique targets in GS
+ totGsTargs = length(uniGsTargs)
+ # Create new edges vector. (Each edge is a tuple (TF, gene))
+ gsEdges = [string(gsRegs[i], ",", gsTargs[i]) for i in 1:length(gsRegs)]
+
+ # If targGeneFile wasn’t provided, use GS targets.
+ if targGeneFile === nothing || isempty(targGeneFile)
+ potTargGenes = uniGsTargs
+ totTargGenes = length(potTargGenes)
+ end
+
+ # Compute evaluation universe and random PR baseline
+ gsTotPotInts = totTargGenes*totGsRegs # complete universe size
+ gsRandPR = totGsInts/gsTotPotInts
+
+ # group gold standard edges by regulator
+ gsEdgesByTf = Array{String}[]
+ gsRandPRbyTf = zeros(totGsRegs)
+ for gind = 1:totGsRegs
+ currInds = findall(x -> x == uniGsRegs[gind], gsRegs)
+ push!(gsEdgesByTf, gsEdges[currInds])
+ gsRandPRbyTf[gind] = length(gsEdgesByTf[gind])/totGsTargs
+ end
+
+ # ----- Part 4. Load and Filter Inferred GRN
+ @info "Loading and Processing Inferred GRN"
+ # Filter "inferred GRN" to include only 'TFs' in GS and 'TargetGenes' in 'potTargGenes'
+ grnData = filter(row -> (row[1] in uniGsRegs) && (row[2] in potTargGenes), trnData)
+ grnData = collect(grnData)
+
+ if isempty(grnData)
+ @warn "No inferred edges after filtering. Skipping evaluation."
+ return emptyPRResult(gsRegs=uniGsRegs, gsTargs=uniGsTargs, gsEdges=gsEdges)
+ end
+
+ # Order by absolute value of weights/confidences
+ grnData = sort(grnData, by = row -> abs(row[3]), rev = true)
+ regs = collect(row[1] for row in grnData)
+ targs = collect(row[2] for row in grnData)
+ rankings = collect(row[3] for row in grnData)
+ absRankings = abs.(rankings)
+ # Create inferred edges vector. (Each edge is a tuple (TF, gene))
+ infEdges = [string(regs[i], ",", targs[i]) for i in 1:length(regs)]
+ totTrnInts = length(infEdges)
+
+ # ----- Part 5. Compute Edge Indicators
+ @info "Computing Edge Indicators and Labels"
+ # Create binary labels: 1 if the inferred edge is in gsEdges, 0 otherwise
+ commonEdgesBinaryVec = [in(edge, gsEdges) ? 1 : 0 for edge in infEdges]
+ @info "Total Interactions w/ GS TFs ($(length(unique(regs)))): $totTrnInts"
+
+ if breakTies # If breaking ties is desired, compute tie-adjusted (mean) vector.
+ # Break ties in weights/confidences to smooth out abrupt jumps caused by arbitrary ordering of tied predictions
+ # Create a dictionary mapping each score to the mean value of commonEdgesVec for tied predictions
+ @info "Tie breaker enabled"
+ uniqueRankings = unique(absRankings)
+ meanIndicator = Dict{Float64, Float64}()
+ for currRank in uniqueRankings
+ inds = findall(x -> x == currRank, absRankings)
+ meanIndicator[currRank] = mean(commonEdgesBinaryVec[inds])
+ end
+ meanEdgesVec = [meanIndicator[ix] for ix in absRankings]
+ edgesVec = meanEdgesVec
+ else # If not breaking ties, use binary vector
+ edgesVec = commonEdgesBinaryVec
+ end
+
+ # ----- Part 6. Compute Overall PR/ROC Performance Metrics
+ @info "Computing Performance Metrics"
+ totalNegatives = gsTotPotInts - totGsInts # Total possible interactions - length(gsEdges) = TN + FP
+ gsPrecisions = zeros(totTrnInts)
+ gsRecalls = zeros(totTrnInts)
+ gsFprs = zeros(totTrnInts)
+
+ cumulativeTP = 0.0 # can be fractional in tie-adjusted mode
+ for idx in 1:totTrnInts
+ cumulativeTP += edgesVec[idx] # Add the effective contribution of this prediction (the running true-positive count)
+ # False positives (FP); this is a weighted FP when ties are broken
+ falsePositives = idx - cumulativeTP
+
+ # Precision: TP divided by total predictions so far.
+ # idx is always the total predicted positives at any point j,
+ # such that j = TP + FP
+ gsPrecisions[idx] = cumulativeTP / idx
+ gsRecalls[idx] = cumulativeTP / totGsInts # TP / (TP + FN) == TP / length(gsEdges)
+ gsFprs[idx] = falsePositives / totalNegatives
+ end
+
+ # Prepend starting point for plotting.
+ if !isempty(gsPrecisions)
+ gsRecalls = vcat(0.0, gsRecalls)
+ gsPrecisions = vcat(gsPrecisions[1], gsPrecisions)
+ gsFprs = vcat(0.0, gsFprs)
+ end
+ # Compute F1-scores
+ gsF1scores = [p + r > 0 ? 2p*r/(p+r) : 0.0 for (p, r) in zip(gsPrecisions, gsRecalls)] # guards against P + R == 0
+
+ # ----- Part 7. Compute AUPR and AUROC
+ @info "Computing AUPR and AUROC"
+ # AUPR is computed with the trapezoidal rule; a step-function approximation is a less robust alternative.
+ heights = (gsPrecisions[2:end] + gsPrecisions[1:end-1])/2
+ widths = gsRecalls[2:end] - gsRecalls[1:end-1]
+ gsAuprs = sum(heights .* widths)
+
+ # PARTIAL AUPR @ partialAUPRlimit
+ indx = findall(r -> r <= partialAUPRlimit, gsRecalls) # Indices of recalls <= partialAUPRlimit
+ recalls_sub = gsRecalls[indx]
+ precisions_sub = gsPrecisions[indx]
+ heights_sub = (precisions_sub[2:end] + precisions_sub[1:end-1]) ./ 2
+ widths_sub = recalls_sub[2:end] - recalls_sub[1:end-1]
+ partialAUPR = sum(heights_sub .* widths_sub)
+
+ # AROC : Trapezoidal Rule
+ widthsRoc = gsFprs[2:end] - gsFprs[1:end-1] # Change in FPR (the x-axis) between successive points.
+ heightsRoc = (gsRecalls[2:end] + gsRecalls[1:end-1]) / 2 # Average TPR (recall) for each segment.
+ gsArocs = sum(widthsRoc .* heightsRoc)
+
+ # ----- Part 8. Save Directory
+ baseName = splitext(basename(infTrnFile))[1]
+ if saveDir === nothing || isempty(saveDir)
+ dateStr = Dates.format(now(), "yyyymmdd_HHMMSS")
+ saveDir = joinpath(pwd(), baseName * "_" * dateStr)
+ end
+ mkpath(saveDir)
+
+ # ----- Part 9. Per-TF and Macro Metrics
+ resultsTF = nothing
+ perTFDict = nothing
+ macroPRDict = nothing
+ macroROCDict = nothing
+ if doPerTF
+ if !isempty(regs) && !isempty(targs) && !isempty(gsEdgesByTf)
+ resultsTF = evaluatePerTF(
+ regs, targs, rankings, infEdges, gsEdgesByTf, uniGsRegs, gsRandPRbyTf;
+ saveDir,
+ breakTies=breakTies, target_points=target_points, min_step=min_step,
+ step_method=step_method, xLimitRecall=xLimitRecall
+ )
+ # Extract Per-TF Results
+ perTFDict = resultsTF[:perTF]
+ macroPRDict = resultsTF[:macroPR]
+ macroROCDict = resultsTF[:macroROC]
+ else
+ @warn "Skipping per-TF evaluation: no valid edges after filtering"
+ end
+ end
+
+
+ # ----- Part 10. Plot and Save Results
+ suffix = breakTies ? "_tiesBroken" : ""
+ baseName = baseName * suffix
+ plotPRCurve(gsRecalls, gsPrecisions, gsRandPR, saveDir; xLimitRecall=xLimitRecall, baseName = baseName)
+ plotROCCurve(gsFprs, gsRecalls, saveDir; baseName = baseName)
+
+
+ results = OrderedDict(
+ :gsRegs => uniGsRegs,
+ :gsTargs => uniGsTargs,
+ :gsEdges => gsEdges,
+ :randPR => gsRandPR,
+ :infRegs => regs,
+ :infEdges => infEdges,
+ :breakTies => breakTies,
+ :stepVals => rankings,
+ :rankings => rankings,
+ :edgesVec => edgesVec,
+ :precisions => gsPrecisions,
+ :recalls => gsRecalls,
+ :fprs => gsFprs,
+ :auprs => Dict(
+ :full => gsAuprs,
+ :partial => Dict(
+ :value => partialAUPR,
+ :recallLimit => partialAUPRlimit
+ )
+ ),
+ :arocs => gsArocs,
+ :f1scores => gsF1scores,
+ # --- Include results from evaluatePerTF ---
+ :perTF => perTFDict,
+ :macroPR => macroPRDict,
+ :macroROC => macroROCDict
+ )
+
+ # Define a standard filename for saving the performance metrics
+ savedFile = joinpath(saveDir, baseName * "_PerformanceMetric.jld")
+ @save savedFile results
+
+ # Add the saved file path to the results, so the caller immediately knows it
+ results[:savedFile] = savedFile
+
+ return results
+end
\ No newline at end of file
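The trapezoidal AUPR in Part 7 is just a midpoint-height sum over recall increments. A minimal standalone sketch with toy precision/recall arrays (hypothetical values, not from any run):

```julia
# Toy PR curve: perfect precision up to recall 0.5, then 0.5 precision out to recall 1.0
precisions = [1.0, 1.0, 0.5, 0.5]
recalls    = [0.0, 0.5, 0.5, 1.0]

# Trapezoidal rule: average adjacent precisions, weight by recall increments
heights = (precisions[2:end] .+ precisions[1:end-1]) ./ 2
widths  = recalls[2:end] .- recalls[1:end-1]
aupr = sum(heights .* widths)   # 1.0*0.5 + 0.75*0.0 + 0.5*0.5 = 0.75
```

Zero-width segments (tied recalls) contribute nothing, which is why tie-breaking above changes the shape of the curve but not the integration rule.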
diff --git a/src/metrics/Constants.jl b/src/metrics/Constants.jl
new file mode 100755
index 0000000..8314a21
--- /dev/null
+++ b/src/metrics/Constants.jl
@@ -0,0 +1,9 @@
+# const GS_REQUIRED_COLS = [:TF, :Target, :Weight]
+# const TRN_REQUIRED_COLS = [:TF, :Target, :Score]
+const MIN_REQUIRED_COLS = 3
+
+const DEFAULT_COLORS = [
+ "#377eb8", "#ff7f00", "#4daf4a", "#e41a1c", "#984ea3", "#a65628",
+ "#f781bf", "#00ced1", "#000000", "#5A9D5A", "#D96D3B", "#FFAD12",
+ "#66628D", "#91569A", "#B6742A", "#DD87B4", "#D26D7A", "#DAA520", "#dede00"
+]
\ No newline at end of file
diff --git a/src/metrics/MetricUtils.jl b/src/metrics/MetricUtils.jl
new file mode 100755
index 0000000..cf53659
--- /dev/null
+++ b/src/metrics/MetricUtils.jl
@@ -0,0 +1,96 @@
+"""
+ adaptiveStep(xs; target_points=1000, min_step=1e-4, method=:min_gap)
+
+Select a step size for interpolation based on actual values in `xs`.
+
+# Arguments
+- `xs`: Array of arrays (per TF) or single array of values.
+- `target_points`: Desired number of points if `method=:target_points`.
+- `min_step`: Minimum allowable step size.
+- `method`: `:min_gap` uses smallest nonzero gap; `:target_points` uses range/target_points.
+"""
+function adaptiveStep(xs; target_points=1000, min_step=1e-4, method=:min_gap)
+ vals = isa(xs[1], AbstractArray) ? reduce(vcat, xs) : xs
+ vals = sort(unique(vals))
+
+ if method == :min_gap
+ gaps = diff(vals)
+ nonzero_gaps = filter(g -> g > 0, gaps)
+ step = isempty(nonzero_gaps) ? min_step : max(minimum(nonzero_gaps), min_step)
+ elseif method == :target_points
+ range_span = maximum(vals) - minimum(vals)
+ step = max(range_span / target_points, min_step)
+ else
+ error("Unknown method: $method")
+ end
+
+ return step
+end
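A quick sanity check of the `:min_gap` rule, reimplemented inline so it runs without the package (toy values, hypothetical):

```julia
# :min_gap picks the smallest nonzero gap between sorted unique values,
# floored at min_step
vals = sort(unique([0.0, 0.1, 0.1, 0.25, 0.3]))
gaps = filter(g -> g > 0, diff(vals))
min_step = 1e-4
step = max(minimum(gaps), min_step)   # smallest gap is 0.30 - 0.25 = 0.05
```

The floor at `min_step` matters when two recall values differ only by floating-point noise; without it the interpolation grid can explode in size.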
+
+
+"""
+ dedupNInterpolate(xs, ys, interpPts)
+
+Linearly interpolate `ys` onto `interpPts`, handling duplicates and invalid values.
+
+Filters `NaN`/`Inf`, sorts by `xs`, collapses duplicate `xs` by keeping max `y`,
+then interpolates using `Flat()` extrapolation. Returns zeros if fewer than 2 valid points.
+"""
+function dedupNInterpolate(xs::Vector{Float64}, ys::Vector{Float64}, interpPts::AbstractVector{Float64})
+ # Filter out invalid values
+ valid_inds = isfinite.(xs) .& isfinite.(ys)
+ xs = xs[valid_inds]
+ ys = ys[valid_inds]
+
+ # Check if there are enough points to interpolate
+ if length(xs) < 2
+ return zeros(length(interpPts))
+ end
+
+ # Sort xs and corresponding ys
+ perm = sortperm(xs)
+ xsSorted = xs[perm]
+ ysSorted = ys[perm]
+
+ # Collapse duplicates: keep max y for each unique x
+ dict = Dict{Float64, Float64}()
+ for (x, y) in zip(xsSorted, ysSorted)
+ dict[x] = haskey(dict, x) ? max(dict[x], y) : y
+ end
+
+ xsUnique = sort(collect(keys(dict)))
+ ysUnique = [dict[x] for x in xsUnique]
+
+ # Handle edge case: all ys are identical (LinearInterpolation still works)
+ itp = LinearInterpolation(xsUnique, ysUnique, extrapolation_bc=Flat())
+ return itp.(interpPts)
+end
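The duplicate-collapsing step (keep the max `y` per unique `x`) can be exercised on its own, without `Interpolations` (toy values, hypothetical):

```julia
xs = [0.2, 0.1, 0.2, 0.3]
ys = [0.5, 0.9, 0.7, 0.4]

# Collapse duplicate xs, keeping the maximum y for each
d = Dict{Float64,Float64}()
for (x, y) in zip(xs, ys)
    d[x] = haskey(d, x) ? max(d[x], y) : y
end
xsUnique = sort(collect(keys(d)))
ysUnique = [d[x] for x in xsUnique]
# x == 0.2 appears twice (0.5 and 0.7); the max, 0.7, survives
```

Strictly increasing `xs` is what `LinearInterpolation` requires, hence the dedup before constructing the interpolant.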
+
+"""
+ validateColumns(cols, required, filename)
+
+Check that all `required` column names are present in `cols`.
+Throws an informative error naming the missing columns if any are absent.
+
+usage:
+validateColumns(Symbol.(propertynames(gsData)), GS_REQUIRED_COLS, gsFile)
+"""
+# function validateColumns(cols, required, filename)
+# missing_cols = setdiff(required, cols)
+# if !isempty(missing_cols)
+# error("""
+# Missing columns in $filename: $(join(missing_cols, ", "))
+# Required columns: $(join(required, ", "))
+# Found columns: $(join(cols, ", "))
+# """)
+# end
+# end
+
+function validateColumnCount(data, filename)
+ if length(propertynames(data)) < MIN_REQUIRED_COLS
+ error("""
+ $filename must have at least 3 columns in order: TF, Target, Score
+ Found: $(length(propertynames(data))) column(s)
+ """)
+ end
+end
\ No newline at end of file
diff --git a/src/metrics/Metrics.jl b/src/metrics/Metrics.jl
new file mode 100755
index 0000000..80324b6
--- /dev/null
+++ b/src/metrics/Metrics.jl
@@ -0,0 +1,33 @@
+using JLD2, InlineStrings
+using PyPlot
+using Colors
+using Dates
+using DataFrames
+using OrderedCollections
+using Measures
+using CSV
+using Base.Threads
+using Interpolations
+using Statistics
+
+# ----------------------------
+# Plot defaults
+# ----------------------------
+
+const dpi = 600
+
+function setPlotDefaults!()
+ rc = PyPlot.matplotlib.rcParams
+
+ rc["font.family"] = "Nimbus Sans"
+ rc["axes.titlesize"] = 9
+ rc["axes.labelsize"] = 9
+ rc["xtick.labelsize"] = 7
+ rc["ytick.labelsize"] = 7
+ rc["legend.fontsize"] = 9
+
+ rc["figure.dpi"] = dpi
+ rc["savefig.dpi"] = dpi
+
+ return nothing
+end
\ No newline at end of file
diff --git a/src/metrics/plotting/PlotBatch.jl b/src/metrics/plotting/PlotBatch.jl
new file mode 100755
index 0000000..4d1946f
--- /dev/null
+++ b/src/metrics/plotting/PlotBatch.jl
@@ -0,0 +1,928 @@
+# ----------------------------------------------------------------------
+# Color Handling
+# ----------------------------------------------------------------------
+
+"""
+ padColors(listFiles)
+
+Ensures enough colors for plotting by extending a predefined palette.
+Returns a vector of hex color codes, generating random colors if needed.
+
+# Arguments
+- `listFiles`: Vector of items needing unique colors.
+"""
+function padColors(listFiles)
+ lineColors = copy(DEFAULT_COLORS)
+ lenFile = length(listFiles)
+ lenColor = length(lineColors)
+ excessCT = lenFile - lenColor
+
+ if excessCT > 0
+ @warn "More entries in listFiles than available colors; generating random colors for the excess."
+ colorGen = ["#" * hex(RGB(rand(), rand(), rand())) for _ in 1:(excessCT + 12)] # Random hex colors ("#RRGGBB"); over-generate to survive collisions with the palette
+ uniqueCols = setdiff(colorGen, lineColors)
+ # Pad lineColors with the unique additional colors
+ lineColors = vcat(lineColors, uniqueCols[1:excessCT])
+ end
+ return lineColors
+end
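Padding only kicks in when there are more curves than palette entries; the arithmetic, with a toy palette and stand-in extras (hypothetical values):

```julia
palette = ["#377eb8", "#ff7f00", "#4daf4a"]
items = ["net1", "net2", "net3", "net4", "net5"]
excess = length(items) - length(palette)   # 2 extra colors needed

extras = ["#111111", "#222222", "#333333"] # stand-ins for randomly generated hex colors
# Drop any extras already in the palette, then take just enough to cover the excess
colors = excess > 0 ? vcat(palette, setdiff(extras, palette)[1:excess]) : copy(palette)
```

The `setdiff` guard is why `padColors` over-generates random colors: duplicates against the fixed palette are discarded before slicing.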
+
+# ----------------------------------------------------------------------
+# Helper: Data Loading
+# ----------------------------------------------------------------------
+
+function loadPRData(source; mode::Symbol=:global)
+ # Load file safely if source is a path
+ r = if isa(source, String)
+ try
+ load(source)["results"]
+ catch e
+ @warn "Could not load $source" exception=(e, catch_backtrace())
+ return nothing
+ end
+ else
+ source
+ end
+
+ if r === nothing
+ return nothing
+ end
+
+ # Determine mode
+ if mode == :macro
+ if !haskey(r, :macroPR)
+ @warn "No macroPR found, falling back to global"
+ r_macro = r
+ else
+ r_macro = r[:macroPR]
+ end
+ return Dict(
+ :precisions => r_macro[:precisions],
+ :recalls => r_macro[:recalls],
+ :randPR => get(r_macro, :randPR, get(r, :randPR, nothing)) # fallback to global randPR
+ )
+ elseif mode == :global
+ return Dict(
+ :precisions => r[:precisions],
+ :recalls => r[:recalls],
+ :randPR => get(r, :randPR, nothing)
+ )
+ elseif mode == :perTF
+ if !haskey(r, :perTF)
+ @warn "No perTF data found"
+ return nothing
+ end
+ return r[:perTF]
+ else
+ error("Unknown mode: $mode")
+ end
+end
+
+# ----------------------------------------------------------------------
+# Transformation Helper
+# ----------------------------------------------------------------------
+"""
+ makeTransform(yScale::String)
+Creates forward and inverse transformation functions for plot scaling.
+
+# Arguments
+- `yScale::String`: Transformation type ("sqrt", "cubert", or "linear")
+
+Returns a tuple of (forward, inverse) functions. Defaults to identity functions if scale type not recognized.
+ """
+function makeTransform(yScale::String)
+ if yScale == "linear"
+ forwardFunc = x -> x
+ inverseFunc = x -> x
+ elseif yScale == "sqrt"
+ forwardFunc = x -> sqrt.(x)
+ inverseFunc = x -> x .^ 2
+ elseif yScale == "cubert"
+ forwardFunc = x -> x .^ (1/3)
+ inverseFunc = x -> x .^ 3
+ else
+ # Any unrecognized yScale falls back to identity functions
+ @warn "Unknown scale type: $yScale. Using linear scale."
+ forwardFunc = x -> x
+ inverseFunc = x -> x
+ end
+ return forwardFunc, inverseFunc
+end
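The forward/inverse pair must round-trip for the y-axis relabeling to stay honest; e.g. for the `"sqrt"` scale:

```julia
# Same pair makeTransform returns for yScale == "sqrt"
forwardFunc = x -> sqrt.(x)
inverseFunc = x -> x .^ 2

y = [0.0, 0.25, 1.0]
yF = forwardFunc(y)       # transformed axis positions
yBack = inverseFunc(yF)   # must recover the original tick values
```

matplotlib's `set_yscale("function", ...)` applies `forwardFunc` to place points and `inverseFunc` to label ticks, so a non-inverse pair silently mislabels the axis.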
+
+# ----------------------------------------------------------------------
+# Plotting Helpers
+# ----------------------------------------------------------------------
+
+
+function plotFileData!(axes::AbstractVector, source, label::String, color::String; lineType=nothing, lineWidth=nothing, mode::Symbol=:global)
+ data = loadPRData(source; mode)
+ if isnothing(data)
+ @warn "Skipping $label ($source): no valid PR data found"
+ return nothing # callers test `!== nothing`, so do not return false here
+ end
+
+ for ax in axes
+ ax.plot(data[:recalls], data[:precisions], label=label, color=color,
+ linewidth=(lineWidth === nothing ? 0.7 : lineWidth),
+ linestyle=(lineType === nothing ? "-" : lineType))
+ end
+ return data
+end
+
+
+# Function to combine legend handles and labels from every axis.
+function combineLegends(fig)
+ allHandles = Any[]
+ allLabels = String[]
+ # Loop over each axis in the figure.
+ for ax in fig.axes
+ handles, labels = ax.get_legend_handles_labels()
+ for (h, l) in zip(handles, labels)
+ if !(l in allLabels)
+ push!(allHandles, h)
+ push!(allLabels, l)
+ end
+ end
+ end
+ return allHandles, allLabels
+end
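`combineLegends` just deduplicates by label across axes; the core logic, with strings standing in for matplotlib handles (hypothetical data):

```julia
# (handle, label) pairs as they would come from two axes of a broken-axis figure
entries = [("h1", "Net A"), ("h2", "Random"), ("h3", "Net A")]

allHandles, allLabels = Any[], String[]
for (h, l) in entries
    if !(l in allLabels)          # keep the first handle seen for each label
        push!(allHandles, h)
        push!(allLabels, l)
    end
end
```

This matters in broken y-axis mode, where every curve is plotted on both panels and would otherwise appear twice in the legend.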
+
+function styleAxis!(ax; xZoom=0.1, yRange::Union{Nothing,Tuple}=nothing,
+ xStep=0.05, yStep=0.1, yScale="linear", setXticks=true, tickLabelSize::Int = 7)
+ # X-axis
+ ax.set_xlim(0, xZoom)
+ if setXticks
+ xMax = mod(xZoom, xStep) == 0 ? xZoom : ceil(xZoom/xStep) * xStep
+ ax.set_xticks(0:xStep:xMax)
+ end
+
+ # Y-axis
+ forwardFunc, inverseFunc = makeTransform(yScale)
+ yMin, yMax = isnothing(yRange) ? (0.0, 1.0) : yRange
+ ax.set_ylim(yMin, yMax)
+
+ # Y ticks
+ yTicks = yMin:yStep:yMax
+ ax.set_yticks(yTicks)
+ ax.set_yticklabels(string.(round.(yTicks; digits=2)))
+ ax.set_yscale("function", functions=(forwardFunc, inverseFunc))
+
+ # Labels
+ # ax.set_xlabel(xlabel, fontsize=axisTitleSize, family="Nimbus Sans")
+ # ax.set_ylabel(ylabel, fontsize=axisTitleSize, family="Nimbus Sans")
+
+ # Grid & ticks
+ ax.grid(true, which="major", linestyle="-", linewidth=0.5, color="lightgray")
+ ax.minorticks_on()
+ ax.grid(true, which="minor", linestyle=":", linewidth=0.25, color="lightgray")
+ ax.tick_params(axis="both", which="both", labelsize=tickLabelSize, direction="out")
+end
+
+"""
+ plotPRCurves(listFilePR, dirOut, saveName; kwargs...)
+
+Plot and save precision-recall curves for one or more networks against a gold standard.
+Supports single-axis and broken y-axis modes. Saves both legend and no-legend versions.
+
+# Arguments
+- `listFilePR` : `OrderedDict` mapping legend labels to either file paths or
+ `OrderedDict`s with `:precisions`, `:recalls`, and `:randPR` keys.
+- `dirOut::String` : Output directory for saved plots.
+- `saveName::String`: Base name for output files.
+
+# Keyword Arguments
+- `xLimitRecall::Float64=0.1` : Maximum recall shown on x-axis.
+- `yZoomPR::Vector=[]` : Y-axis display range.
+ `[]` → full range 0–1.
+ `[0.8]` → limit y-axis to 0–0.8.
+ `[0.4, 0.9]` → broken y-axis with gap between 0.4 and 0.9.
+- `xStepSize::Union{Nothing,Real}=nothing` : X-axis tick step size. Default auto-set to 0.05.
+- `yStepSize::Union{Nothing,Real}=nothing` : Y-axis tick step size. Default auto-set to 0.2.
+- `yScale::String="linear"` : Y-axis scale — `"linear"`, `"sqrt"`, or `"cubert"`.
+- `isInside::Bool=true` : Place legend inside (`true`) or outside (`false`) the plot.
+- `lineColors::Vector=[]` : Custom line colors. Empty uses default palette.
+- `lineTypes::Vector=[]` : Custom line styles. Empty uses solid lines.
+- `lineWidths::Vector=[]` : Custom line widths. Empty uses default width.
+- `heightRatios=nothing` : Height ratios for broken y-axis panels `[topHeight, bottomHeight]`.
+ `nothing` → auto-computed proportionally from `yZoomPR` values.
+ Pass explicit values, e.g. `[3.0, 0.5]`, to control panel sizes manually;
+ a larger value gives a taller panel.
+- `mode::Symbol=:global` : PR data to use:
+ `:global` → overall PR curve.
+ `:macro` → macro-averaged PR (falls back to global if missing).
+ `:perTF` → per-TF PR curves.
+
+# Returns
+Path to the saved plot file.
+
+# Example
+```julia
+plotPRCurves(listFilePR, "plots", "MyNetwork";
+ xLimitRecall=0.1, yZoomPR=[0.4, 0.9],
+ isInside=false, mode=:macro)
+```
+"""
+function plotPRCurves(listFilePR, dirOut::String, saveName::String;
+ xLimitRecall = 0.1, yZoomPR = [], xStepSize::Union{Nothing,Real}=nothing,
+ yStepSize::Union{Nothing,Real}=nothing,
+ yScale::String="linear", isInside::Bool=true,
+ lineColors=[], # empty vector means "use default"
+ lineTypes=[], # empty vector means "use default",
+ lineWidths = [],
+ heightRatios=nothing, # nothing → auto-computed from yZoomPR in broken-axis mode (matches docstring)
+ mode::Symbol = :global
+ )
+
+ # Defaults
+ xStep = isnothing(xStepSize) ? 0.05 : xStepSize
+ yStep = isnothing(yStepSize) ? 0.2 : yStepSize
+ dirOut = isempty(dirOut) ? pwd() : mkpath(dirOut)
+ # Style parameters
+ axisTitleSize = 9
+ tickLabelSize = 7
+ # plotTitleSize = 16
+ legendSize = 9
+
+ # Color assignment
+ if isempty(lineColors)
+ lineColors = padColors(listFilePR)
+ end
+
+ # Auto height ratios if broken axis
+ if length(yZoomPR) == 2 && heightRatios === nothing
+ lowerSpan = max(yZoomPR[1], 0.0) # bottom panel span (0 → yZoomPR[1])
+ upperSpan = max(1.0 - yZoomPR[2], 0.0) # top panel span (yZoomPR[2] → 1)
+ # protect against zero spans
+ if upperSpan == 0.0
+ upperSpan = 1e-6
+ end
+ if lowerSpan == 0.0
+ lowerSpan = 1e-6
+ end
+ # gridspec expects [top_height, bottom_height]
+ heightRatios = [upperSpan, lowerSpan]
+ end
+
+ ## Making Plots
+ if isempty(yZoomPR) || length(yZoomPR) == 1
+ yZoom = isempty(yZoomPR) ? 1.0 : yZoomPR[1]
+
+ # Create a Single PyPlot figure
+ fig, ax = subplots(figsize=(2, 2), layout="constrained")
+ styleAxis!(ax; xZoom=xLimitRecall, yRange=(0,yZoom), xStep=xStep, yStep=yStep, yScale=yScale)
+
+ @info "Plotting PR curves"
+ lastPlotData = nothing
+ for (idx, (legendLabel, currFilePR)) in enumerate(listFilePR)
+ @info "Plot $idx: $legendLabel" file=currFilePR
+ # If lineType is provided and nonempty, then pick its element if available.
+ currentLineType = (length(lineTypes) ≥ idx && lineTypes[idx] != "") ? lineTypes[idx] : nothing
+ currentLineWidth = (length(lineWidths) ≥ idx && lineWidths[idx] != "") ? lineWidths[idx] : nothing
+ lastPlotData = plotFileData!([ax], currFilePR, legendLabel, lineColors[idx];
+ lineType=currentLineType, lineWidth = currentLineWidth, mode)
+ end
+
+ # Plot random PR line from the last file if available
+ if lastPlotData !== nothing && lastPlotData[:randPR] !== nothing
+ randPR = lastPlotData[:randPR]
+ ax.axhline(randPR, linestyle="-.", linewidth=2, color=[0.6, 0.6, 0.6], label="Random")
+ end
+
+ # # Set labels and legend
+ ax.set_xlabel("Recall", fontsize=axisTitleSize, family="Nimbus Sans")
+ ax.set_ylabel("Precision", fontsize=axisTitleSize, family="Nimbus Sans")
+ # save figure without legend
+ if saveName !== nothing && !isempty(saveName)
+ savePath = joinpath(dirOut, string(saveName, "_PR_noLegend.pdf"))
+ else
+ savePath = joinpath(dirOut, string(Dates.format(now(), "yyyymmdd_HHMMSS"), "_PR_noLegend.pdf"))
+ end
+ PyPlot.savefig(savePath, dpi=600)
+
+ # Add Figure Legends
+ if isInside
+ ax.legend(borderaxespad=0.2, frameon=true)
+ else
+ ax.legend(loc="center", bbox_to_anchor=(1.25, 0.5), borderaxespad=0.2, frameon=false)
+ fig.set_size_inches(3.5, 2.5) # wider only for outside legend
+ end
+ savePathLegend = replace(savePath, "PR_noLegend.pdf" => "PR.pdf")
+ PyPlot.savefig(savePathLegend, dpi=600)
+
+ elseif length(yZoomPR) == 2
+ # ---- Two-subplot (broken y-axis) mode ----
+ fig, (ax1, ax2) = subplots(2, 1, sharex=true, figsize=(2, 2), layout="constrained",
+ gridspec_kw=Dict("height_ratios" => heightRatios, "hspace" => 0.0))
+
+ # Upper axis
+ styleAxis!(ax1; xZoom=xLimitRecall, yRange=(yZoomPR[2],1.0), xStep=xStep, yStep=yStep, yScale=yScale, setXticks=false)
+ # Lower axis
+ styleAxis!(ax2; xZoom=xLimitRecall, yRange=(0,yZoomPR[1]), xStep=xStep, yStep=yStep, yScale=yScale)
+
+ # Plot curves
+ lastPlotData = nothing
+ @info "Plotting PR curves"
+ for (idx, (legendLabel, currFilePR)) in enumerate(listFilePR)
+ @info "Plot $idx: $legendLabel; File: $currFilePR"
+ # If lineType is provided and nonempty, then pick its element if available.
+ currentLineType = (length(lineTypes) ≥ idx && lineTypes[idx] != "") ? lineTypes[idx] : nothing
+ currentLineWidth = (length(lineWidths) ≥ idx && lineWidths[idx] != "") ? lineWidths[idx] : nothing
+ lastPlotData = plotFileData!([ax1, ax2], currFilePR, legendLabel, lineColors[idx];
+ lineType=currentLineType, lineWidth = currentLineWidth, mode)
+
+ end
+
+ # Extract randPR from the last file
+ # Plot the random PR line on both axes if available.
+ if lastPlotData !== nothing && lastPlotData[:randPR] !== nothing
+ randPR = lastPlotData[:randPR]
+ # println(randPR)
+ for ax in (ax1, ax2)
+ ax.axhline(randPR, linestyle="-.", linewidth=2, color=[0.6, 0.6, 0.6], label="Random")
+ end
+ end
+
+ # Hide the spines between ax1 and ax2
+ ax1.spines["bottom"].set_visible(false)
+ ax2.spines["top"].set_visible(false)
+ ax1.tick_params(axis="x", which="both", bottom=false, top=false, labelbottom=false, labeltop=false)
+ ax2.xaxis.tick_bottom()
+
+ # Add break indicators
+ d = 0.006 # Size of the diagonal lines in axes coordinates
+ # Diagonals at the bottom edge of ax1 (y ≈ 0 in ax1's axes coordinates)
+ ax1.plot([-d, +d], [-d, +d]; transform=ax1.transAxes, color="k", clip_on=false) # Bottom-left diagonal
+ ax1.plot([1 - d, 1 + d], [-d, +d]; transform=ax1.transAxes, color="k", clip_on=false) # Bottom-right diagonal
+ # Diagonals at the top edge of ax2 (y ≈ 1 in ax2's axes coordinates)
+ ax2.plot([-d, +d], [1 - d, 1 + d]; transform=ax2.transAxes, color="k", clip_on=false) # Top-left diagonal
+ ax2.plot([1 - d, 1 + d], [1 - d, 1 + d]; transform=ax2.transAxes, color="k", clip_on=false) # Top-right diagonal
+
+ # Common labels: a shared y-axis label and x-axis label on the lower plot.
+ # fig.text(0.01, 0.5, "Precision", va="center", rotation="vertical", fontsize=axisTitleSize)
+ fig.supylabel("Precision", fontsize=axisTitleSize)
+ ax2.set_xlabel("Recall", fontsize=axisTitleSize)
+
+
+ # save figure without legend
+ if saveName !== nothing && !isempty(saveName)
+ savePath = joinpath(dirOut, string(saveName, "_PR_noLegend.pdf"))
+ else
+ savePath = joinpath(dirOut, string(Dates.format(now(), "yyyymmdd_HHMMSS"), "_PR_noLegend.pdf"))
+ end
+ PyPlot.savefig(savePath, dpi=600)
+
+ # Add Figure Legends
+ if isInside
+ ax2.legend(fontsize=legendSize, borderaxespad=0.2, frameon=true)
+ else
+ ax2.legend(fontsize=legendSize, loc="center", bbox_to_anchor=(1.25, 0.5), borderaxespad=0.2, frameon=false)
+ fig.set_size_inches(5, 4) # wider only for outside legend
+ end
+
+ savePathLegend = replace(savePath, "PR_noLegend.pdf" => "PR.pdf")
+ PyPlot.savefig(savePathLegend, dpi=600)
+
+ end
+ PyPlot.close("all")
+end
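The timestamp fallback used for the save path in both branches above can be factored out and checked on its own. `pickStem` below is a hypothetical helper written for illustration, not part of the codebase:

```julia
using Dates

# Hypothetical helper mirroring the save-name fallback used above:
# use saveName when present and nonempty, otherwise a "yyyymmdd_HHMMSS" timestamp.
pickStem(saveName) = (saveName !== nothing && !isempty(saveName)) ?
                     saveName : Dates.format(now(), "yyyymmdd_HHMMSS")

println(pickStem("myRun"))          # "myRun"
println(length(pickStem(nothing)))  # 15 characters, e.g. "20250101_120000"
```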
+
+
+# -------------------------------------------------------------------------------------
+# Part 2: Making a DotPlot or BarPlot of AUPR using PyPlot
+# -------------------------------------------------------------------------------------
+function plotAUPR(gsParam::OrderedDict{String,<:AbstractDict}, dirOut::String;
+ saveName::Union{Nothing, String} = nothing,
+ metricType::String="full",
+ figSize::Tuple{Real,Real}=(5,5),
+ axisTitleSize::Int=9, tickLabelSize::Int=7,
+ legendFontSize::Int=9, tickRotation::Int=45,
+ plotType::String="",
+ saveLegend::Bool=true)
+
+ # ----------------------
+ # Configure matplotlib
+ # rc = PyPlot.matplotlib["rcParams"]
+ # rc["font.family"] = "Nimbus Sans"
+ # rc["axes.labelsize"] = axisTitleSize
+ # rc["xtick.labelsize"] = tickLabelSize
+ # rc["ytick.labelsize"] = tickLabelSize
+ # rc["legend.fontsize"] = legendFontSize
+
+
+ # Ensure output directory exists
+ dirOut = joinpath(dirOut, "AUPR")
+ mkpath(dirOut)
+
+ # Load AUPR values
+ function loadAupr(filePath::String, metricType::String)
+ try
+ data = load(filePath)
+ auprs = data["results"][:auprs]
+ return lowercase(metricType) == "partial" ? auprs[:partial][:value] : auprs[:full]
+ catch e
+ @warn "Could not load AUPR from $filePath" exception=(e, catch_backtrace())
+ return nothing
+ end
+ end
+
+ lineColors = copy(DEFAULT_COLORS)
+ # Build dataframe
+ gsNames = collect(keys(gsParam))
+ netNames = unique(vcat([collect(keys(files)) for files in values(gsParam)]...))
+
+ dfRows = Vector{NamedTuple{(:xGroups,:Network,:AUPR),Tuple{String,String,Float64}}}()
+ for gs in gsNames, net in netNames
+ filePath = haskey(gsParam[gs], net) ? gsParam[gs][net] : nothing
+ # And update the caller to skip nothing values
+ auprVal = filePath === nothing ? nothing : loadAupr(filePath, metricType)
+ push!(dfRows, (xGroups=gs, Network=net, AUPR=something(auprVal, NaN))) # NaN is visible in plot
+ # push!(dfRows, (xGroups=gs, Network=net, AUPR=auprVal))
+ end
+
+ df = DataFrame(dfRows)
+ numGS = length(unique(df.xGroups))
+ numNet = length(unique(df.Network))
+
+ # Decide plot type automatically if user hasn't overridden
+ autoPlotType = ""
+ if lowercase(plotType) in ["bar","dot"]
+ autoPlotType = lowercase(plotType)
+ else
+ autoPlotType = (numGS == 1 || numNet == 1) ? "dot" : "bar"
+ end
+
+ xLabelVal = (numGS == 1 ? "Network" : "Gold Standards")
+ yLabelVal = lowercase(metricType) == "partial" ? "Partial AUPR" : "AUPR"
+
+ # Function to generate a plot
+ function make_plot(includeLegend::Bool)
+ fig, ax = plt.subplots(figsize=figSize, layout="constrained")
+ colors = lineColors[1:numNet]
+
+ if autoPlotType == "bar"
+ if numGS == 1
+ for (i, net) in enumerate(netNames)
+ idxs = findall(df.Network .== net)
+ ax.bar(i, df.AUPR[idxs], color=colors[i], label=(includeLegend ? net : nothing))
+ end
+ ax.set_xticks(1:numNet)
+ ax.set_xticklabels(netNames, rotation=tickRotation)
+ elseif numNet == 1
+ ax.bar(1:numGS, df.AUPR, color=colors[1], label=(includeLegend ? netNames[1] : nothing)) # single call, no broadcast: avoids duplicate legend entries
+ ax.set_xticks(1:numGS)
+ ax.set_xticklabels(gsNames, rotation=tickRotation)
+ else
+ barWidth = 0.15
+ groupWidth = numNet * barWidth
+ groupSpacing = 0.4
+ groupPositions = [i*(groupWidth + groupSpacing) for i in 0:(numGS-1)]
+ for j in 1:numNet
+ offset = (j-(numNet+1)/2)*barWidth
+ xPos = [gp + groupWidth/2 + offset for gp in groupPositions]
+ auprs = [df.AUPR[(i-1)*numNet + j] for i in 1:numGS]
+ ax.bar(xPos, auprs, barWidth, color=colors[j], label=(includeLegend ? netNames[j] : nothing))
+ end
+ ax.set_xticks([gp + groupWidth/2 for gp in groupPositions])
+ ax.set_xticklabels(gsNames, rotation=tickRotation)
+ end
+ elseif autoPlotType == "dot"
+ if numGS == 1
+ for (i, net) in enumerate(netNames)
+ idxs = findall(df.Network .== net)
+ ax.scatter(i, df.AUPR[idxs], color=colors[i], s=40, label=(includeLegend ? net : nothing))
+ end
+ ax.set_xticks(1:numNet)
+ ax.set_xticklabels(netNames, rotation=tickRotation)
+ elseif numNet == 1
+ ax.scatter(1:numGS, df.AUPR, color=colors[1], s=40, label=(includeLegend ? netNames[1] : nothing))
+ ax.set_xticks(1:numGS)
+ ax.set_xticklabels(gsNames, rotation=tickRotation)
+ else
+ for j in 1:numNet
+ xvals = 1:numGS
+ yvals = [df.AUPR[(i-1)*numNet + j] for i in 1:numGS]
+ ax.scatter(xvals, yvals, color=colors[j], s=40, label=(includeLegend ? netNames[j] : nothing))
+ end
+ ax.set_xticks(1:numGS)
+ ax.set_xticklabels(gsNames, rotation=tickRotation)
+ end
+ end
+
+ ax.set_xlabel(xLabelVal, fontsize=axisTitleSize)
+ ax.set_ylabel(yLabelVal, fontsize=axisTitleSize)
+ ax.tick_params(axis="both", which="both")
+ # ax.tick_params(axis="both", which="both", labelsize=tickLabelSize, direction="in")
+ ax.grid(true, which="major", linestyle="-", linewidth=0.5, color="lightgray")
+ if includeLegend
+ # ax.legend(fontsize=legendFontSize, loc="best")
+ ax.legend(fontsize=legendFontSize, loc="center", bbox_to_anchor=(1.25, 0.5), borderaxespad=0.2, frameon=false)
+ else
+ # Remove x-axis labels and ticks when no legend
+ ax.set_xticks([])
+ ax.set_xticklabels([])
+ ax.set_xlabel("")
+ end
+
+ return fig, ax
+ end
+
+ saveName = (saveName === nothing || isempty(saveName)) ? "AUPR_$(Dates.format(now(), "yyyymmdd_HHMMSS"))" : saveName
+ # Save plot according to saveLegend
+ savePath = joinpath(dirOut, saveName * "_$(autoPlotType)_" * (saveLegend ? "withLegend" : "noLegend") * ".pdf")
+ fig, _ = make_plot(saveLegend)
+ fig.savefig(savePath, dpi=600)
+ plt.close(fig)
+ println("Saved plot to:\n $savePath")
+end
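The centered-offset arithmetic in the grouped-bar branch of `make_plot` can be verified in isolation. The numbers below are hypothetical, chosen to match the defaults used above:

```julia
# Standalone check of the grouped-bar position math used in make_plot:
# each of numNet bars gets a symmetric offset around its group's center.
numNet = 3
barWidth = 0.15
groupWidth = numNet * barWidth                                  # 0.45
offsets = [(j - (numNet + 1) / 2) * barWidth for j in 1:numNet]
println(offsets)   # bars cluster symmetrically around gp + groupWidth/2
```

With three bars, the offsets come out as `[-0.15, 0.0, 0.15]`, so adding them to the group center places the middle bar exactly on it.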
+
+
+# function plotAUPR(gsParam::OrderedDict{String,<:AbstractDict}, dirOut::String, saveName::String; figSize::Tuple{Real, Real})
+
+# """
+# Generate a visualization of Area Under the Precision-Recall Curve (AUPR) values using either a bar plot or a dot plot, depending on the input data structure.
+# - If there are multiple gold standards and networks, a bar plot is created, with each group of bars representing a gold standard and each bar within a group representing a network.
+# - If there is only one gold standard or one network, a dot plot is created, with each dot representing a network or gold standard.
+# - The plot is saved as a PDF file in the specified `dirOut` directory, with a resolution of 600 DPI.
+
+
+# # Arguments
+# - `gsParam::OrderedDict{String,Dict{String,String}}`: Mapping of gold standards to network file paths.
+# - `dirOut::String`: Directory to save the plot.
+# - `saveName::String`: Base name for the output file.
+
+# # Keywords
+# - `figSize::Tuple{Real, Real}=(5, 5)`: Size of the figure.
+
+# # Returns
+# Nothing, but saves the plot to a file.
+# """
+
+# function loadAupr(filePath)
+# try
+# data = load(filePath)
+# return data["results"][:auprs]
+# catch e
+# println("Error loading $filePath: $e")
+# return nothing
+# end
+# end
+
+# # Check if dirOut is either nothing or an empty string
+# dirOut = if dirOut === nothing || isempty(dirOut)
+# pwd()
+# else
+# mkpath(dirOut)
+# end
+
+# if isempty(figSize) || length(figSize) == 1
+# println("Warning: figSize must be a tuple of 2 element")
+# figSize = (5,5)
+# end
+
+# axisTitleSize = 16
+# tickLabelSize = 14
+# plotTitleSize = 16
+# legendSize = 10
+
+# lineColors = [ #Initial Color palette
+# "#377eb8", "#ff7f00", "#4daf4a", "#f781bf", "#a65628",
+# "#984ea3", "#e41a1c", "#00ced1", "#000000",
+# "#5A9D5A", "#D96D3B", "#FFAD12", "#66628D", "#91569A",
+# "#B6742A", "#DD87B4", "#D26D7A", "#dede00"]
+
+# # Load aupr values
+# gsNames = collect(keys(gsParam))
+# netNames = unique(vcat([collect(keys(files)) for files in values(gsParam)]...))
+# # auprValues = [loadAupr(gsParam[gs][netName]) !== nothing ? loadAupr(gsParam[gs][netName]) : 0.0 for gs in gsNames, netName in netNames]
+
+# auprValues = []
+# # Load AUPR values for each GS and file
+# for gs in gsNames
+# for netName in netNames
+# # Load the AUPR value
+# filePath = gsParam[gs][netName]
+# aupr = loadAupr(filePath)
+
+# # Use 0.0 if the AUPR value is missing
+# push!(auprValues, aupr === nothing ? 0.0 : aupr)
+# end
+# end
+# # Reshape auprValues into a matrix
+# auprMatrix = transpose(reshape(auprValues, length(netNames), length(gsNames)))
+
+# boolNet = any(length(netNames) > 1 for netNames in values(gsParam))
+# numNet = length(netNames)
+# numGS = length(gsNames)
+# # =
+# # ---- Load AUPR values (Works but not using)
+# # auprData = Dict(gs => Dict(name => loadAupr(path) for (name, path) in files) for (gs, files) in gsParam)
+# # netNames = unique(vcat([collect(keys(files)) for files in values(gsParam)]...))
+# # auprValues = [get(auprData[gs], netName, 0.0) for gs in gsNames, netName in netNames]
+# = #
+# if boolNet && numGS > 1
+# # Create a figure
+# fig, ax = plt.subplots(figsize=figSize, layout="constrained")
+# # Define parameters for grouping:
+# barWidth = 0.15 # Width of each bar.
+# groupSpacing = 0.4 # Extra space between groups.
+# groupWidth = numNet * barWidth # Total width occupied by bars in one group.
+
+# # Compute positions for each group.
+# # Compute a vector where each element is the left edge (or starting point) of each group.
+# # Use 0-based indexing for groups.
+# groupPositions = [i * (groupWidth + groupSpacing) for i in 0:(numGS-1)]
+
+# # Now plot each bar (each network) in every group.
+# # Use a centered offset for each bar in the group.
+# for idx in 1:numNet
+# # Calculate the offset to center the bar within the group.
+# # (idx - (numNet+1)/2)*barWidth shifts bars so they cluster around the center.
+# offset = (idx - (numNet+1)/2)*barWidth
+# # The x-positions for the bars are:
+# # groupPositions + offset + groupWidth/2
+# # Adding groupWidth/2 centers the bars in each group.
+# x_positions = [gp + groupWidth/2 + offset for gp in groupPositions]
+# # Plot the bar for curr network across all groups.
+# ax.bar(x_positions, auprMatrix[:, idx], barWidth, label=netNames[idx],
+# color = lineColors[mod1(idx, length(lineColors))])
+# end
+
+# # Set the xticks to the center positions of each group.
+# groupCenters = [gp + groupWidth/2 for gp in groupPositions]
+# ax.set_xticks(groupCenters)
+# ax.set_xticklabels(gsNames)
+# ax.set_xlabel("Gold Standards", fontsize=axisTitleSize)
+# ax.set_ylabel("AUPR", fontsize=axisTitleSize)
+# ax.tick_params(axis="both", which="both", labelsize=tickLabelSize)
+# ax.legend(fontsize=legendSize, loc="center left", bbox_to_anchor=(1, 0.5))
+
+# elseif numGS == 1 || numNet == 1
+# # Dot Plot
+# fig, ax = subplots(figsize= figSize, layout="constrained")
+# if numGS == 1
+# # One GS, multiple files
+# for idx in 1:numNet
+# ax.scatter(idx, auprMatrix[idx], color=lineColors[idx], label=netNames[idx])
+# ax.text(idx, auprMatrix[idx], netNames[idx], fontsize=9, ha="right")
+# end
+# ax.set_xlabel("Network", fontsize=axisTitleSize)
+# ax.set_xticks([])
+# ax.set_xticklabels([])
+# # ax.grid(true, which="major", linestyle="-", linewidth=0.5, color="lightgray")
+# else
+# for idx in 1:numGS
+# ax.scatter(idx, auprMatrix[idx], color=lineColors[idx], label=gsNames[idx])
+# ax.text(idx, auprMatrix[idx], gsNames[idx], fontsize=9, ha="right")
+# end
+# ax.set_xlabel("Gold Standard", fontsize=axisTitleSize)
+# ax.set_xticks([])
+# ax.set_xticklabels([])
+# # ax.grid(true, which="major", linestyle="-", linewidth=0.5, color="lightgray")
+# end
+# end
+
+# # Common settings for both plots.
+# ax.grid(true, which="major", linestyle="-", linewidth=0.5, color="lightgray")
+# ax.tick_params(axis="both", which="both", labelsize=tickLabelSize)
+# # plt.tight_layout()
+
+
+# if !isempty(saveName)
+# plt.savefig(joinpath(dirOut, saveName * "_AUPR.pdf"), dpi=600)
+# else
+# dateStr = Dates.format(now(), "yyyymmdd_HHMMSS")
+# plt..savefig(joinpath(dirOut, dateStr * "_AUPR.pdf"), dpi=600)
+# end
+
+# plt.close("all")
+# end
+
+
+
+
+
+
+# ---------------------------------------------------------------------------------------------------------------------------------------------------------------------
+# ------ Using StatsPlots
+# ---------------------------------------------------------------------------------------------------------------------------------------------------------------------
+# function plotAUPR(gsParam::OrderedDict{String,<:AbstractDict}, dirOut::String,
+# saveName::String, plotType::String = "bar";
+# figSize::Tuple{Real,Real} = (8,5), axisTitleSize::Int = 13,
+# tickLabelSize::Int = 10)
+
+# legendFontSize = 9
+
+# # Check if dirOut is either nothing or an empty string
+# dirOut = if dirOut === nothing || isempty(dirOut)
+# pwd()
+# else
+# mkpath(dirOut)
+# end
+
+# # Dummy load function – replace with your actual data-loading mechanism.
+# function loadAupr(filePath)
+# try
+# data = load(filePath)
+# return data["results"][:auprs]
+# catch e
+# println("Error loading $filePath: $e")
+# return nothing
+# end
+# end
+
+# # Initial color palette
+# lineColors = ["#377eb8", "#ff7f00", "#4daf4a", "#f781bf", "#a65628",
+# "#984ea3", "#e41a1c", "#00ced1", "#000000",
+# "#5A9D5A", "#D96D3B", "#FFAD12", "#66628D", "#91569A",
+# "#B6742A", "#DD87B4", "#D26D7A", "#dede00"]
+
+# # Build a DataFrame with columns: xGroups, Network, AUPR
+# gsNames = collect(keys(gsParam))
+# netNames = unique(vcat([collect(keys(files)) for files in values(gsParam)]...))
+# if length(netNames) > length(lineColors)
+# lineColors = padColors(netNames) # Ensure padColors is defined!
+# end
+
+# dfRows = Vector{NamedTuple{(:xGroups, :Network, :AUPR), Tuple{String,String,Float64}}}()
+# for gs in gsNames
+# for net in netNames
+# filePath = haskey(gsParam[gs], net) ? gsParam[gs][net] : nothing
+# temp = filePath === nothing ? nothing : loadAupr(filePath)
+# val = (temp === nothing || temp === missing) ? 0.0 : temp
+# push!(dfRows, (xGroups = gs, Network = net, AUPR = val))
+# end
+# end
+# df = DataFrame(dfRows)
+
+# # Number of unique groups.
+# numGS = length(unique(df.xGroups))
+# numNet = length(unique(df.Network))
+
+# # Set common labels and legend; use camelCase.
+# xLabelVal = (numGS == 1 ? "Network" : "Gold Standards")
+# yLabelVal = "AUPR"
+# legendPos = :outerright
+
+# # Common arguments (for scatter and bar, without color arguments).
+# commonArgs = (xlabel = xLabelVal, ylabel = yLabelVal, legend = legendPos,
+# guidefontsize = axisTitleSize, tickfontsize = tickLabelSize, legendfontsize = legendFontSize,
+# framestyle = :box)
+
+# colors = lineColors[1:numNet]
+
+# p = nothing
+# if lowercase(plotType) == "scatter"
+# if numGS == 1
+# p = @df df sp.scatter(:Network, :AUPR; markersize = 6,
+# color = colors, xrotation = 45, commonArgs...)
+# elseif numNet == 1
+# p = @df df sp.scatter(:xGroups, :AUPR; markersize = 6,
+# color = lineColors[1], xrotation = 45, commonArgs...)
+# else
+# p = @df df sp.scatter(:xGroups, :AUPR; group = :Network, markersize = 6,
+# palette = colors, commonArgs...)
+# end
+
+# elseif lowercase(plotType) == "bar"
+# if numGS == 1
+# p = @df df sp.bar(:Network, :AUPR;
+# color = colors, xrotation = 45, commonArgs...)
+# elseif numNet == 1
+# p = @df df sp.bar(:xGroups, :AUPR;
+# color = lineColors[1], xrotation = 45, commonArgs...)
+# else
+# p = @df df sp.groupedbar(:xGroups, :AUPR; group = :Network,
+# palette = colors, commonArgs...)
+# end
+# else
+# error("Invalid plot type. Please choose 'scatter' or 'bar'")
+# end
+
+# # Adjust figure size (e.g. figSize = (6,3.5) converts to (600,350) pixels)
+# p = sp.plot(p, size = (Int(figSize[1]*100), Int(figSize[2]*100)), titlefontsize = axisTitleSize)
+
+
+# if !isempty(saveName)
+# savePath = joinpath(dirOut, saveName * "_AUPR_$(plotType).pdf")
+# sp.savefig(p, savePath)
+# else
+# dateStr = Dates.format(now(), "yyyymmdd_HHMMSS")
+# savePath = joinpath(dirOut, dateStr * "_AUPR_$(plotType).pdf")
+# sp.savefig(p, savePath)
+# end
+# end
+
+
+
diff --git a/src/metrics/plotting/PlotSingle.jl b/src/metrics/plotting/PlotSingle.jl
new file mode 100755
index 0000000..a202236
--- /dev/null
+++ b/src/metrics/plotting/PlotSingle.jl
@@ -0,0 +1,94 @@
+"""
+ plotPRCurve(rec, prec, randPR, saveDir; kwargs...)
+
+Plot and save a Precision-Recall (PR) curve using PyPlot.
+
+# Arguments
+- `rec::AbstractVector` : Recall values.
+- `prec::AbstractVector` : Precision values.
+- `randPR::Float64` : Random baseline precision (horizontal reference line).
+- `saveDir::String` : Directory to save the plot.
+
+# Keyword Arguments
+- `xLimitRecall::Float64=1.0` : Maximum recall shown on the x-axis of diagnostic PR plots.
+- `axisTitleSize::Int=16` : Font size for axis titles.
+- `tickLabelSize::Int=14` : Font size for tick labels.
+- `plotTitleSize::Int=18` : Font size for plot title.
+- `baseName=nothing` : Base filename. If `nothing`, a timestamp is used.
+"""
+function plotPRCurve(
+ rec::AbstractVector, prec::AbstractVector, randPR::Float64, saveDir::String;
+ xLimitRecall::Float64 = 1.0, axisTitleSize::Int = 16,
+ tickLabelSize::Int = 14, plotTitleSize::Int = 18, baseName = nothing)
+
+ if baseName === nothing
+ baseName = Dates.format(now(), "yyyymmdd_HHMMSS")
+ end
+
+ if !isempty(rec) && !isempty(prec) && any(prec .> 0)
+ PyPlot.figure()
+ PyPlot.axhline(y=randPR, linestyle="-.", color="k")
+ PyPlot.plot(rec, prec, color="b")
+ PyPlot.xlabel("Recall", fontsize=axisTitleSize)
+ PyPlot.ylabel("Precision", fontsize=axisTitleSize)
+ PyPlot.title("PR Curve: $baseName", fontsize=plotTitleSize)
+ PyPlot.xlim(0, xLimitRecall)
+ PyPlot.ylim(0, 1)
+ PyPlot.grid(true, which="major", linestyle="--", linewidth=0.75, color="gray")
+ PyPlot.minorticks_on()
+ PyPlot.grid(true, which="minor", linestyle=":", linewidth=0.5, color="lightgray")
+ PyPlot.tick_params(axis="both", which="major", labelsize=tickLabelSize)
+ PyPlot.tick_params(axis="both", which="minor", labelsize=tickLabelSize - 2)
+ PyPlot.savefig(joinpath(saveDir, baseName * "_PR.png"), dpi=600)
+ PyPlot.close()
+ else
+ @warn "Skipping PR plot for $baseName: empty or all-zero precision/recall vectors"
+ end
+end
+
+
+"""
+ plotROCCurve(fpr, tpr, saveDir; kwargs...)
+
+Plot and save a ROC curve using PyPlot.
+
+# Arguments
+- `fpr::AbstractVector` : False positive rate values (x-axis).
+- `tpr::AbstractVector` : True positive rate / recall values (y-axis).
+- `saveDir::String` : Directory to save the plot.
+
+# Keyword Arguments
+- `axisTitleSize::Int=16` : Font size for axis titles.
+- `tickLabelSize::Int=14` : Font size for tick labels.
+- `plotTitleSize::Int=18` : Font size for plot title.
+- `baseName=nothing` : Base filename. If `nothing`, a timestamp is used.
+"""
+function plotROCCurve(
+ fpr::AbstractVector, tpr::AbstractVector, saveDir::String;
+ axisTitleSize::Int = 16, tickLabelSize::Int = 14,
+ plotTitleSize::Int = 18, baseName = nothing)
+
+ if baseName === nothing
+ baseName = Dates.format(now(), "yyyymmdd_HHMMSS")
+ end
+
+ if !isempty(fpr) && !isempty(tpr) && (maximum(tpr) > 0 || maximum(fpr) > 0)
+ PyPlot.figure()
+ PyPlot.plot([0, 1], [0, 1], linestyle="--", color="k")
+ PyPlot.plot(fpr, tpr, color="b")
+ PyPlot.xlabel("FPR", fontsize=axisTitleSize)
+ PyPlot.ylabel("TPR", fontsize=axisTitleSize)
+ PyPlot.title("ROC Curve: $baseName", fontsize=plotTitleSize)
+ PyPlot.xlim(0, 1)
+ PyPlot.ylim(0, 1)
+ PyPlot.grid(true, which="major", linestyle="--", linewidth=0.75, color="gray")
+ PyPlot.minorticks_on()
+ PyPlot.grid(true, which="minor", linestyle=":", linewidth=0.5, color="lightgray")
+ PyPlot.tick_params(axis="both", which="major", labelsize=tickLabelSize)
+ PyPlot.tick_params(axis="both", which="minor", labelsize=tickLabelSize - 2)
+ PyPlot.savefig(joinpath(saveDir, baseName * "_ROC.png"), dpi=600)
+ PyPlot.close()
+ else
+ @warn "Skipping ROC plot for $baseName: empty or all-zero TPR/FPR vectors"
+ end
+end
\ No newline at end of file
diff --git a/src/prior/MergeDegenerateTFs.jl b/src/prior/MergeDegenerateTFs.jl
new file mode 100755
index 0000000..671cb7a
--- /dev/null
+++ b/src/prior/MergeDegenerateTFs.jl
@@ -0,0 +1,366 @@
+ using CSV
+ using DataFrames
+ using Printf
+ using FileIO
+ using Base.Iterators: partition
+ using Base.Threads
+
+ """
+ This script merges transcription factors (TFs) with identical target sets and interaction weights
+ into meta-TFs, producing a merged prior network and summary tables.
+
+ # Inputs
+ - `networkFile::String` — Path to the input network file.
+ - `outFileBase::Union{String, Nothing}` — Base path for output files. If `nothing`, defaults to the input file's base name.
+ - `fileFormat::Int` — Format of the input network:
+ - `1` = Long format (TF, Target, Weight)
+ - `2` = Wide format (Targets × TFs matrix)
+ - `connector::String` (default = "_") — String used to join TF names into merged meta-TF names.
+
+ # Output Files
+ 1. `*_merged_sp.tsv` — Long-format network with merged TF names.
+ 2. `*_merged.tsv` — Wide-format matrix of the merged network (targets × regulators).
+ 3. `*_overlaps.tsv` — TF-by-TF overlap matrix showing shared targets.
+ 4. `*_targetTotals.tsv` — Number of targets per (possibly merged) TF.
+ 5. `*_mergedTfs.tsv` — Mappings from each merged meta-TF to its original TF members.
+
+ # Notes
+ - Run Julia with multiple threads (e.g., `julia --threads=6`) to enable parallel processing.
+ - Ensure that TF names do not already include the `connector` string to avoid unintended merges.
+ - Zeros are removed in the long-format representation before merging.
+ - The returned merged network is written in wide format and can be used for downstream analysis.
+
+ # Returns
+ A `NamedTuple` with:
+ - `merged::DataFrame` — Wide-format merged matrix (targets x meta-TFs).
+ - `mergedTfs::Matrix{String}` — A two-column matrix where each row contains a meta-TF name and a comma-separated list of its constituent TFs.
+ """
+
+ # Struct defined in src/Types.jl
+
+
+ """
+     countOverlap(s1, s2)
+
+ Count the number of shared elements between two sets.
+
+ Returns the size of the intersection without materializing a new set.
+ """
+ function countOverlap(s1::AbstractSet, s2::AbstractSet)
+ # Iterate over the smaller set
+ if length(s1) > length(s2)
+ s1, s2 = s2, s1
+ end
+ c = 0
+ for t in s1
+ if t in s2
+ c += 1
+ end
+ end
+ return c
+ end
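The membership-count idiom above can be illustrated on plain sets, without the function itself; note that tuples carry the weight, so weights must match for an element to count as shared (toy data below):

```julia
# Standalone illustration of the idiom in countOverlap: iterate one set
# and test membership in the other; (target, weight) must match exactly.
a = Set([("g1", "1.0"), ("g2", "0.5"), ("g3", "2.0")])
b = Set([("g2", "0.5"), ("g3", "9.9")])
shared = count(t -> t in b, a)
println(shared)   # 1 -- ("g3", …) differs in weight, so it does not count
```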
+
+ """
+     readNetwork(networkFile; fileFormat)
+
+ Parse a regulatory network file in either long or wide format.
+
+ # Arguments
+ - `networkFile::String`: Path to the network file.
+ - `fileFormat::Int`:
+     - `1` = long format (TF, Target, Weight)
+     - `2` = wide format (Targets × TFs matrix)
+
+ # Returns
+ - A dictionary mapping TF names to sets of (target, weight) tuples.
+
+ # Notes
+ - Entries with weight zero are ignored.
+ - Weights are stored as strings (converted to Float64 later).
+ """
+ function readNetwork(networkFile::String; fileFormat::Int)
+ # Initialize dictionary to store TF-target interactions
+ tfTargDic = Dict{String, Set{Tuple{String, String}}}()
+
+ # Parse the network file based on format
+ if fileFormat == 1
+ # Long format: each line is TF<TAB>Target<TAB>Weight.
+ open(networkFile, "r") do file
+ readline(file) # skip header
+ for line in eachline(file)
+ line = strip(line)
+ if isempty(line)
+ continue
+ end
+ parts = split(line, '\t')
+ if length(parts) < 3
+ continue
+ end
+ tfName, target, weight = parts[1], parts[2], parts[3]
+ if weight != "0"
+ if !haskey(tfTargDic, tfName)
+ tfTargDic[tfName] = Set{Tuple{String, String}}()
+ end
+ push!(tfTargDic[tfName], (target, weight))
+ end
+ end
+ println("TF-target dictionary created.")
+ end
+
+ elseif fileFormat == 2
+ # Wide format (TFs as columns, Targets as rows, and cells as interactions)
+ df = CSV.File(networkFile, header=1; delim='\t') |> DataFrame
+ melted = stack(df, Not(1))
+ melted = rename!(melted, names(melted)[1] => :Target, names(melted)[2] => :TF, names(melted)[3] => :Weight)
+ melted = select!(melted, :TF, :Target, :Weight)
+ melted = filter(row -> row.Weight != 0, melted)
+ for row in eachrow(melted)
+ tfName = row.TF
+ target = row.Target
+ weight = row.Weight
+ if !haskey(tfTargDic, tfName)
+ tfTargDic[tfName] = Set{Tuple{String, String}}()
+ end
+ push!(tfTargDic[tfName], (target, string(weight)))
+ end
+ end
+ return tfTargDic
+ end
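The long-format branch can be exercised without a file; this sketch replays the same split/filter/push! logic on an in-memory string (all TF and gene names below are hypothetical):

```julia
# In-memory replay of the long-format parsing loop in readNetwork.
longText = "TF\tTarget\tWeight\nTF1\tgeneA\t1.0\nTF1\tgeneB\t0\nTF2\tgeneA\t0.5\n"
dic = Dict{String, Set{Tuple{String, String}}}()
for line in split(strip(longText), '\n')[2:end]   # skip the header line
    tf, target, weight = split(line, '\t')
    weight == "0" && continue                     # drop zero-weight edges
    push!(get!(dic, tf, Set{Tuple{String, String}}()), (String(target), String(weight)))
end
# "TF1" keeps only geneA; its zero-weight geneB edge is dropped
```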
+
+
+ """
+     groupRedundantTFs(tfTargDic)
+
+ Compute pairwise TF overlaps and group TFs with identical target–weight sets into meta-TFs.
+
+ # Arguments
+ - `tfTargDic::Dict{String, Set{Tuple{String, String}}}`: Mapping from TFs to their target–weight pairs.
+
+ # Returns
+ - `tfMergers::Dict{String, Vector{String}}`: TFs grouped into merge sets.
+ - `overlaps::Dict{Tuple{String, String}, Int}`: Pairwise TF overlap counts.
+ - `tfTargNums::Dict{String, Int}`: Number of targets per TF.
+ - `tfNames::Vector{String}`: All TF names (sorted).
+
+ # Notes
+ - Uses threading if multiple threads are available.
+ - TFs are merged only if their target–weight sets are exactly identical.
+ """
+ function groupRedundantTFs(tfTargDic::Dict{String, Set{Tuple{String, String}}})
+ # Check if more than 1 thread is available
+ nthreads = Threads.nthreads()
+ use_threads = nthreads > 1
+
+ # Step A: Determine TF overlaps
+ tfNames = sort(collect(keys(tfTargDic)))
+ tfMergers = Dict{String, Vector{String}}()
+ overlaps = Dict{Tuple{String, String}, Int}()
+ tfTargNums = Dict{String, Int}()
+
+ # Fill tfTargNums and diagonal of overlap matrix
+ for (tf, targets) in tfTargDic
+ n = length(targets)
+ tfTargNums[tf] = n
+ overlaps[(tf,tf)] = n #seed the diagonal of the overlaps matrix
+ end
+
+ if use_threads
+ localOverlapsArray = [Dict{Tuple{String, String}, Int}() for i in 1:nthreads]
+ localTFMergersArray = [Dict{String, Set{String}}() for i in 1:nthreads]
+
+ println("Running in threaded mode ...")
+ Threads.@threads for i in 1:length(tfNames)
+ # Use the thread's id (an integer between 1 and nthreads) to access its local storage.
+ tid = threadid()
+ localOverlaps = localOverlapsArray[tid]
+ localMergers = localTFMergersArray[tid]
+
+ tf1 = tfNames[i]
+ tf1targets = tfTargDic[tf1]
+ for j in (i+1):length(tfNames)
+ tf2 = tfNames[j]
+ tf2targets = tfTargDic[tf2]
+ overlapSize = countOverlap(tf1targets, tf2targets)
+
+ # Save both orderings if needed.
+ localOverlaps[(tf1, tf2)] = overlapSize
+ localOverlaps[(tf2, tf1)] = overlapSize
+
+ # If the overlap equals each TF's target count then they fully overlap.
+ if tfTargNums[tf1] == overlapSize && tfTargNums[tf2] == overlapSize
+ # Update local mergers for tf1.
+ localMergers[tf1] = get!(localMergers, tf1, Set([tf1]))
+ push!(localMergers[tf1], tf2)
+
+ # Update local mergers for tf2.
+ localMergers[tf2] = get!(localMergers, tf2, Set([tf2]))
+ push!(localMergers[tf2], tf1)
+ end
+ end
+ end
+
+ # Merge the thread-local results into global dictionaries.
+ for lo in localOverlapsArray
+ for (k,v) in lo
+ overlaps[k] = v
+ end
+ end
+
+ for localx in localTFMergersArray
+ for (tf, mergerSet) in localx
+ if haskey(tfMergers, tf)
+ # Combine the sets, then convert back to vector making sure they're unique.
+ unionSet = union(Set(tfMergers[tf]), mergerSet)
+ tfMergers[tf] = collect(unionSet)
+ else
+ tfMergers[tf] = collect(mergerSet)
+ end
+ end
+ end
+
+ else
+ println("Running in sequential mode ...")
+ for i in 1:length(tfNames)
+
+ tf1 = tfNames[i]
+ tf1targets = tfTargDic[tf1]
+
+ for j in i+1:length(tfNames)
+ tf2 = tfNames[j]
+ tf2targets = tfTargDic[tf2]
+
+ overlapSize = countOverlap(tf1targets, tf2targets)
+ overlaps[(tf1, tf2)] = overlapSize
+ overlaps[(tf2, tf1)] = overlapSize
+
+ if tfTargNums[tf1] == tfTargNums[tf2] == overlapSize
+ # Use get! to set default values and append
+ push!(get!(tfMergers, tf1, [tf1]), tf2)
+ push!(get!(tfMergers, tf2, [tf2]), tf1)
+ end
+ end
+ end
+ end
+ println("TF overlap determination complete.")
+
+ return tfMergers, overlaps, tfTargNums, tfNames
+ end
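The merge rule above (two TFs merge only when their pairwise overlap equals both of their target counts) can be sketched on toy data; `overlap` and `shouldMerge` are illustrative helpers, not codebase functions:

```julia
# Toy check of the full-overlap merge rule used in groupRedundantTFs.
targets = Dict("TFa" => Set([("g1", "1.0"), ("g2", "2.0")]),
               "TFb" => Set([("g1", "1.0"), ("g2", "2.0")]),
               "TFc" => Set([("g1", "1.0")]))
overlap(x, y) = length(intersect(targets[x], targets[y]))
shouldMerge(x, y) = length(targets[x]) == length(targets[y]) == overlap(x, y)
println(shouldMerge("TFa", "TFb"))   # true: identical target/weight sets
println(shouldMerge("TFa", "TFc"))   # false: TFc covers only a subset
```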
+
+
+ """
+     mergeDegenerateTFs(mergedTFsData, networkFile;
+                        outFileBase=nothing, fileFormat=2, connector="_")
+
+ Merge TFs with identical target–weight sets into meta-TFs, write the five
+ output files described at the top of this script, and populate
+ `mergedTFsData.mergedPrior` and `mergedTFsData.mergedTFMap` in place.
+
+ # Arguments
+ - `mergedTFsData::mergedTFsResult`: result struct to fill (see src/Types.jl).
+ - `networkFile::String`: path to the prior network file.
+
+ # Keywords
+ - `outFileBase`: output directory; if `nothing`, outputs are written next to
+   `networkFile` using its filename stem.
+ - `fileFormat`: 1 = long format, 2 = wide format.
+ - `connector`: string joining merged TF names (default `"_"`).
+ """
+ function mergeDegenerateTFs(
+     mergedTFsData::mergedTFsResult,
+     networkFile::String;
+     outFileBase::Union{String,Nothing}=nothing,
+     fileFormat::Int = 2,
+     connector::String = "_"
+ )
+
+ # 1. ----- Read + compute mergers, overlaps, counts, names...
+ tfTargDic = readNetwork(networkFile; fileFormat)
+ tfMergers, overlaps, tfTargNums, tfNames = groupRedundantTFs(tfTargDic)
+
+ # 2. ----- Write output files
+
+ # If no outFileBase was supplied, derive it from networkFile:
+ stem = splitext(basename(networkFile))[1] # filename without extension
+ if outFileBase === nothing
+ dir = dirname(networkFile)
+ outFileBase = joinpath(dir, stem)
+ else
+ outFileBase = joinpath(outFileBase,stem)
+ end
+ netOutFile = outFileBase * "_merged_sp.tsv"
+ netMatOutFile = outFileBase * "_merged.tsv"
+ overlapsOutFile = outFileBase * "_overlaps.tsv"
+ totTargOutFile = outFileBase * "_targetTotals.tsv"
+ mergedTfsOutFile = outFileBase * "_mergedTFs.tsv"
+
+ # Open all
+ netIO = open(netOutFile, "w")
+ overlapsIO = open(overlapsOutFile, "w")
+ totTargIO = open(totTargOutFile, "w")
+ mergedTfsIO = open(mergedTfsOutFile, "w")
+
+ # In-memory collector
+ netDF = DataFrame(Regulator=String[], Target=String[], Weight=String[])
+ tabMergedTFs = Vector{Vector{String}}() # will hold lines "metaTF\tmember1, member2,…"
+ allMergedTfs = collect(keys(tfMergers))
+ usedMergedTfs = Set{String}() # keeps track of used TFs, so we don't output mergers twice
+ printedTfs = String[]
+
+ try
+ # Write header to network output file.
+ println(netIO, "Regulator\tTarget\tWeight")
+ for tf in tfNames
+ tfPrint = nothing
+ doPrint = false
+ if tf in allMergedTfs && !(tf in usedMergedTfs)
+ mergedTfs = sort(collect(tfMergers[tf]))
+ union!(usedMergedTfs, mergedTfs)
+ tfPrint = join(mergedTfs[1:2], connector) * (length(mergedTfs) > 2 ? "..." : "")
+ # 1) write the merged-TF mapping to disk
+ line = tfPrint * "\t" * join(mergedTfs, ", ")
+ println(mergedTfsIO, line)
+ # 2) store in memory
+ push!(tabMergedTFs, [tfPrint, join(mergedTfs, ", ")])
+ doPrint = true
+
+ elseif !(tf in allMergedTfs)
+ tfPrint = tf
+ doPrint = true
+ end
+
+ if doPrint
+ # Write target totals.
+ println(totTargIO, "$(tfPrint)\t$(tfTargNums[tf])")
+ # Write each TF-target-weight record to a line.
+ for (targ, wgt) in tfTargDic[tf]
+ outline = "$(tfPrint)\t$(targ)\t$(wgt)"
+ println(netIO, outline)
+ # also push into netDF
+ push!(netDF, (Regulator=tfPrint, Target=targ, Weight=wgt))
+ end
+ push!(printedTfs, tfPrint)
+ end
+ end
+
+ println(overlapsIO, "\t" * join(printedTfs, "\t"))
+ for tf in printedTfs
+ row = [ string(overlaps[(first(split(tf,connector)), first(split(pt,connector)))])
+ for pt in printedTfs ]
+ println(overlapsIO, tf * '\t' * join(row, '\t'))
+ end
+ println("Overlap output file written.")
+ finally
+ close(netIO); close(overlapsIO); close(totTargIO); close(mergedTfsIO)
+ end
+
+ netDF.Weight = parse.(Float64, netDF.Weight)
+ mergedPrior = convertToWide(netDF; indices=(2, 1, 3))
+ mergedPrior .= coalesce.(mergedPrior, 0.0)
+ # Write to file, making sure the first column header is empty.
+ write_files && writeTSVWithEmptyFirstHeader(mergedPrior, netMatOutFile; delim = '\t')
+
+ write_files && println("Output files:\n$mergedTfsOutFile\n$totTargOutFile\n$netOutFile\n$netMatOutFile\n$overlapsOutFile")
+ mergedTFsData.mergedPrior = mergedPrior
+ mergedTFsData.mergedTFMap = reduce(vcat, permutedims.(tabMergedTFs)) # tabMergedTFs as a two-column matrix
+
+ return (merged = mergedPrior, mergedTfs = first.(tabMergedTFs))
+ end
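The wide-matrix step above (`convertToWide` plus `coalesce.(…, 0.0)`) is a pivot of the long edge list into a targets × regulators matrix with absent pairs zero-filled. A minimal Base-only sketch of that pivot, with made-up edge data (not the package's implementation):

```julia
# Pivot a long edge list (regulator, target, weight) into a dense
# targets × regulators matrix; pairs with no edge stay 0.0.
edges = [("TF1", "geneA", 1.0), ("TF1", "geneB", 0.5), ("TF2", "geneB", 2.0)]

regs  = sort(unique(e[1] for e in edges))   # column labels
targs = sort(unique(e[2] for e in edges))   # row labels

wide = zeros(Float64, length(targs), length(regs))
for (tf, g, w) in edges
    wide[findfirst(==(g), targs), findfirst(==(tf), regs)] = w
end
# wide == [1.0 0.0; 0.5 2.0], rows geneA/geneB, columns TF1/TF2
```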
diff --git a/test/runtests.jl b/test/runtests.jl
new file mode 100755
index 0000000..55314bf
--- /dev/null
+++ b/test/runtests.jl
@@ -0,0 +1,497 @@
+# test/runtests.jl
+using Test
+using InferelatorJL
+using Random
+using Statistics
+using LinearAlgebra
+
+# =============================================================================
+# Helpers — write synthetic input files to a temp directory
+# =============================================================================
+
+"""
+Write a gene expression TSV:
+ - Row 1: empty cell + sample names
+ - Rows 2+: gene name + Float64 values
+"""
+function write_expr_tsv(path, geneNames, sampleNames, mat)
+ open(path, "w") do io
+ println(io, "\t" * join(sampleNames, "\t"))
+ for (i, g) in enumerate(geneNames)
+ println(io, g * "\t" * join(string.(mat[i, :]), "\t"))
+ end
+ end
+end
+
+"""
+Write a gene list file (one name per line).
+"""
+write_gene_list(path, names) = write(path, join(names, "\n") * "\n")
+
+"""
+Write a sparse prior TSV:
+ - Row 1: empty cell + TF names
+ - Rows 2+: gene name + Float64 values
+"""
+function write_prior_tsv(path, tfNames, geneNames, mat)
+ open(path, "w") do io
+ println(io, "\t" * join(tfNames, "\t"))
+ for (i, g) in enumerate(geneNames)
+ println(io, g * "\t" * join(string.(mat[i, :]), "\t"))
+ end
+ end
+end
+
+
+# =============================================================================
+@testset "InferelatorJL" begin
+# =============================================================================
+
+
+# -----------------------------------------------------------------------------
+@testset "Structs instantiate" begin
+# -----------------------------------------------------------------------------
+ @test GeneExpressionData() isa GeneExpressionData
+ @test PriorTFAData() isa PriorTFAData
+ @test mergedTFsResult() isa mergedTFsResult
+ @test GrnData() isa GrnData
+ @test BuildGrn() isa BuildGrn
+end
+
+
+# -----------------------------------------------------------------------------
+@testset "Data loading — expression matrix" begin
+# -----------------------------------------------------------------------------
+ tmpdir = mktempdir()
+
+ # Genes intentionally out of alphabetical order to test sort
+ genes = ["Zap70", "Akt1", "Myc", "Foxp3"]
+ samples = ["S1", "S2", "S3", "S4", "S5"]
+ mat = Float64[
+ 1.0 2.0 3.0 4.0 5.0; # Zap70
+ 6.0 7.0 8.0 9.0 10.0; # Akt1
+ 0.1 0.2 0.1 0.2 0.1; # Myc
+ 3.0 3.5 4.0 4.5 5.0; # Foxp3
+ ]
+ exprFile = joinpath(tmpdir, "expr.txt")
+ write_expr_tsv(exprFile, genes, samples, mat)
+
+ data = GeneExpressionData()
+ InferelatorJL.loadExpressionData!(data, exprFile)
+
+ # Gene names should be sorted alphabetically
+ @test data.geneNames == sort(genes)
+
+ # Matrix rows must match the sorted order
+ sorted_order = sortperm(genes) # [2,4,3,1] → Akt1,Foxp3,Myc,Zap70
+ @test data.geneExpressionMat == mat[sorted_order, :]
+
+ # Sample labels are parsed from the header
+ @test data.cellLabels == samples
+
+ # Dimensions: genes × samples
+ @test size(data.geneExpressionMat) == (length(genes), length(samples))
+end
+
+
+# -----------------------------------------------------------------------------
+@testset "Data loading — non-numeric first column values are gene names" begin
+# -----------------------------------------------------------------------------
+ # Verify that a gene name like "1500009L16Rik" (starts with a digit) is
+ # treated as a gene name, not discarded or confused with a header.
+ tmpdir = mktempdir()
+ genes = ["1500009L16Rik", "Bcl2", "Cd44"]
+ samples = ["A", "B", "C"]
+ mat = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0]
+
+ exprFile = joinpath(tmpdir, "expr_numnames.txt")
+ write_expr_tsv(exprFile, genes, samples, mat)
+
+ data = GeneExpressionData()
+ InferelatorJL.loadExpressionData!(data, exprFile)
+
+ @test "1500009L16Rik" in data.geneNames
+ @test length(data.geneNames) == 3
+end
+
+
+# -----------------------------------------------------------------------------
+@testset "Data loading — invalid file path throws" begin
+# -----------------------------------------------------------------------------
+ data = GeneExpressionData()
+ @test_throws ErrorException InferelatorJL.loadExpressionData!(data, "/nonexistent/path/expr.txt")
+end
+
+
+# -----------------------------------------------------------------------------
+@testset "Data loading — target gene filtering" begin
+# -----------------------------------------------------------------------------
+ tmpdir = mktempdir()
+ genes = ["Akt1", "Bcl2", "Myc", "Foxp3", "Stat3"]
+ samples = ["S1", "S2", "S3", "S4", "S5", "S6"]
+
+ # Myc has zero variance — should be removed with default epsilon
+ mat = Float64[
+ 1.0 2.0 3.0 4.0 5.0 6.0; # Akt1 — good variance
+ 2.0 2.5 3.0 2.5 2.0 3.5; # Bcl2 — good variance
+ 1.0 1.0 1.0 1.0 1.0 1.0; # Myc — ZERO variance
+ 0.5 1.5 2.5 3.5 4.5 5.5; # Foxp3 — good variance
+ 3.0 1.0 4.0 1.0 5.0 9.0; # Stat3 — good variance
+ ]
+ exprFile = joinpath(tmpdir, "expr.txt")
+ write_expr_tsv(exprFile, genes, samples, mat)
+
+ # Request 4 genes (exclude Stat3); Myc will be removed by variance filter
+ targFile = joinpath(tmpdir, "targs.txt")
+ write_gene_list(targFile, ["Akt1", "Bcl2", "Myc", "Foxp3"])
+
+ data = GeneExpressionData()
+ InferelatorJL.loadExpressionData!(data, exprFile)
+ InferelatorJL.loadAndFilterTargetGenes!(data, targFile; epsilon = 0.01)
+
+ # Myc must have been removed
+ @test !("Myc" in data.targGenes)
+
+ # The other three requested genes should be present
+ @test "Akt1" in data.targGenes
+ @test "Bcl2" in data.targGenes
+ @test "Foxp3" in data.targGenes
+
+ # targGeneMat dimensions must match retained genes × samples
+ @test size(data.targGeneMat) == (3, length(samples))
+end
+
+
+# -----------------------------------------------------------------------------
+@testset "Data loading — target file not found throws" begin
+# -----------------------------------------------------------------------------
+ tmpdir = mktempdir()
+ genes = ["Akt1", "Bcl2"]
+ samples = ["S1", "S2", "S3"]
+ mat = [1.0 2.0 3.0; 4.0 5.0 6.0]
+ exprFile = joinpath(tmpdir, "expr.txt")
+ write_expr_tsv(exprFile, genes, samples, mat)
+
+ data = GeneExpressionData()
+ InferelatorJL.loadExpressionData!(data, exprFile)
+ @test_throws ErrorException InferelatorJL.loadAndFilterTargetGenes!(
+ data, "/no/such/file.txt")
+end
+
+
+# -----------------------------------------------------------------------------
+@testset "Data loading — potential regulators" begin
+# -----------------------------------------------------------------------------
+ tmpdir = mktempdir()
+ genes = ["Akt1", "Bcl2", "Foxp3", "Myc"]
+ samples = ["S1", "S2", "S3"]
+ mat = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0; 10.0 11.0 12.0]
+ exprFile = joinpath(tmpdir, "expr.txt")
+ write_expr_tsv(exprFile, genes, samples, mat)
+
+ # Regulator list: 2 present in expression, 1 not present
+ regFile = joinpath(tmpdir, "regs.txt")
+ write_gene_list(regFile, ["Foxp3", "Akt1", "NotInExpr"])
+
+ data = GeneExpressionData()
+ InferelatorJL.loadExpressionData!(data, exprFile)
+ InferelatorJL.loadPotentialRegulators!(data, regFile)
+
+ # Only the 2 regulators present in expression data should appear
+ @test length(data.potRegs) == 2
+ @test "Foxp3" in data.potRegs
+ @test "Akt1" in data.potRegs
+ @test !("NotInExpr" in data.potRegs)
+
+ # mRNA matrix: rows = regulators that have expression, cols = samples
+ @test size(data.potRegMatmRNA) == (2, length(samples))
+end
+
+
+# =============================================================================
+# Helper — build a fully loaded GeneExpressionData from scratch in memory
+# =============================================================================
+function make_test_data(tmpdir)
+ genes = ["Akt1", "Bcl2", "Foxp3", "Myc", "Stat3"]
+ tfs = ["Foxp3", "Stat3"] # regulators
+ samples = ["S$i" for i in 1:10]
+
+ Random.seed!(42)
+ mat = rand(Float64, length(genes), length(samples)) .* 5 .+ 0.5
+
+ exprFile = joinpath(tmpdir, "expr.txt")
+ targFile = joinpath(tmpdir, "targs.txt")
+ regFile = joinpath(tmpdir, "regs.txt")
+
+ write_expr_tsv(exprFile, genes, samples, mat)
+ write_gene_list(targFile, genes) # all genes are targets
+ write_gene_list(regFile, tfs)
+
+ data = GeneExpressionData()
+ InferelatorJL.loadExpressionData!(data, exprFile)
+ InferelatorJL.loadAndFilterTargetGenes!(data, targFile; epsilon = 0.01)
+ InferelatorJL.loadPotentialRegulators!(data, regFile)
+ InferelatorJL.processTFAGenes!(data, "") # use all genes for TFA
+
+ return data, genes, tfs, samples
+end
+
+
+# -----------------------------------------------------------------------------
+@testset "TFA — output dimensions" begin
+# -----------------------------------------------------------------------------
+ tmpdir = mktempdir()
+ data, genes, tfs, samples = make_test_data(tmpdir)
+
+ nTFs = length(tfs)
+ nGenes = length(genes)
+ nSamples = length(samples)
+
+ # Prior: every gene has at least one edge, and each TF has several targets
+ # Shape: genes × tfs (written as rows=genes, cols=tfs)
+ prior_mat = Float64[
+ 1 0; # Akt1 → Foxp3
+ 1 1; # Bcl2 → Foxp3, Stat3
+ 0 1; # Foxp3 → Stat3
+ 1 1; # Myc → both
+ 1 1; # Stat3 → both
+ ]
+ priorFile = joinpath(tmpdir, "prior.tsv")
+ write_prior_tsv(priorFile, tfs, sort(genes), prior_mat)
+
+ tfaData = PriorTFAData()
+ InferelatorJL.processPriorFile!(tfaData, data, priorFile;
+ mergedTFsData = nothing,
+ minTargets = 1)
+ InferelatorJL.calculateTFA!(tfaData, data; edgeSS = 0, zTarget = false)
+
+ # medTfas must be nTFs × nSamples (result of priorMatrix \ targExpression)
+ @test size(tfaData.medTfas, 1) == length(tfaData.pRegs)
+ @test size(tfaData.medTfas, 2) == nSamples
+end
+
+
+# -----------------------------------------------------------------------------
+@testset "TFA — least-squares correctness" begin
+# -----------------------------------------------------------------------------
+ # With a full-rank, well-conditioned prior, the fitted TFA equals
+ # prior \ expression (the least-squares solution) to machine precision.
+ tmpdir = mktempdir()
+
+ genes = ["Akt1", "Bcl2", "Foxp3"]
+ tfs = ["Foxp3", "Akt1"]
+ samples = ["S1", "S2", "S3", "S4"]
+
+ Random.seed!(7)
+ expr_mat = rand(Float64, length(genes), length(samples)) .+ 1.0
+
+ exprFile = joinpath(tmpdir, "expr.txt")
+ targFile = joinpath(tmpdir, "targs.txt")
+ regFile = joinpath(tmpdir, "regs.txt")
+ write_expr_tsv(exprFile, genes, samples, expr_mat)
+ write_gene_list(targFile, genes)
+ write_gene_list(regFile, tfs)
+
+ # Full prior: all 3 genes × 2 TFs, every cell nonzero (both TFs have >1 target)
+ prior_mat = Float64[
+ 1.0 0.5;
+ 0.5 1.0;
+ 0.8 0.3;
+ ]
+ priorFile = joinpath(tmpdir, "prior.tsv")
+ write_prior_tsv(priorFile, tfs, genes, prior_mat)
+
+ data = GeneExpressionData()
+ InferelatorJL.loadExpressionData!(data, exprFile)
+ InferelatorJL.loadAndFilterTargetGenes!(data, targFile; epsilon = 0.01)
+ InferelatorJL.loadPotentialRegulators!(data, regFile)
+ InferelatorJL.processTFAGenes!(data, "")
+
+ tfaData = PriorTFAData()
+ InferelatorJL.processPriorFile!(tfaData, data, priorFile;
+ mergedTFsData = nothing, minTargets = 1)
+ InferelatorJL.calculateTFA!(tfaData, data; edgeSS = 0, zTarget = false)
+
+ # Manually compute expected TFA: priorMatrix \ targExpression
+ expected = tfaData.priorMatrix \ tfaData.targExpression
+ @test tfaData.medTfas ≈ expected atol=1e-10
+end
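The identity this testset relies on — TFA as the least-squares solution `prior \ expression` — can be checked on synthetic data: build expression from known activities, then recover them exactly (matrices below are illustrative, not from the package):

```julia
using LinearAlgebra

P = [1.0 0.5; 0.5 1.0; 0.8 0.3]   # genes × TFs prior, full column rank
A = [2.0 1.0; 1.0 3.0]            # true activities (TFs × samples)
E = P * A                         # expression consistent with the prior

Ahat = P \ E                      # least-squares solve, as in calculateTFA!
@assert Ahat ≈ A                  # exact recovery (up to floating point)
```

Because `E` lies exactly in the column space of `P`, the least-squares solution coincides with the true activities; with noisy expression it would be the best fit instead.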
+
+
+# -----------------------------------------------------------------------------
+@testset "TFA — z-scored targets change output" begin
+# -----------------------------------------------------------------------------
+ tmpdir = mktempdir()
+ data, genes, tfs, samples = make_test_data(tmpdir)
+
+ prior_mat = Float64[
+ 1 0; 1 1; 0 1; 1 1; 1 1;
+ ]
+ priorFile = joinpath(tmpdir, "prior.tsv")
+ write_prior_tsv(priorFile, tfs, sort(genes), prior_mat)
+
+ tfaData_raw = PriorTFAData()
+ InferelatorJL.processPriorFile!(tfaData_raw, data, priorFile;
+ mergedTFsData = nothing, minTargets = 1)
+ InferelatorJL.calculateTFA!(tfaData_raw, data; edgeSS = 0, zTarget = false)
+
+ tfaData_z = PriorTFAData()
+ InferelatorJL.processPriorFile!(tfaData_z, data, priorFile;
+ mergedTFsData = nothing, minTargets = 1)
+ InferelatorJL.calculateTFA!(tfaData_z, data; edgeSS = 0, zTarget = true)
+
+ # z-scored and non-z-scored TFA should differ
+ @test !isapprox(tfaData_raw.medTfas, tfaData_z.medTfas; atol=1e-6)
+end
+
+
+# -----------------------------------------------------------------------------
+@testset "Prior — minTargets filter removes low-coverage TFs" begin
+# -----------------------------------------------------------------------------
+ tmpdir = mktempdir()
+ data, genes, tfs, samples = make_test_data(tmpdir)
+
+ # TF1 (Foxp3) has 4 targets; TF2 (Stat3) has only 2
+ prior_mat = Float64[
+ 1 0; # Akt1 → Foxp3 only
+ 1 0; # Bcl2 → Foxp3 only
+ 1 1; # Foxp3 → both
+ 1 0; # Myc → Foxp3 only
+ 0 1; # Stat3 → Stat3 only (1 target)
+ ]
+ priorFile = joinpath(tmpdir, "prior.tsv")
+ write_prior_tsv(priorFile, tfs, sort(genes), prior_mat)
+
+ tfaData = PriorTFAData()
+ # minTargets = 3: Foxp3 has 4 targets (passes), Stat3 has 2 (fails)
+ InferelatorJL.processPriorFile!(tfaData, data, priorFile;
+ mergedTFsData = nothing, minTargets = 3)
+
+ @test "Foxp3" in tfaData.pRegs
+ @test !("Stat3" in tfaData.pRegs)
+end
+
+
+# -----------------------------------------------------------------------------
+@testset "Penalty matrix — self-regulatory edges set to Inf in TFmRNA mode" begin
+# -----------------------------------------------------------------------------
+ tmpdir = mktempdir()
+ data, genes, tfs, samples = make_test_data(tmpdir)
+
+ prior_mat = Float64[
+ 1 0; 1 1; 0 1; 1 1; 1 1;
+ ]
+ priorFile = joinpath(tmpdir, "prior.tsv")
+ write_prior_tsv(priorFile, tfs, sort(genes), prior_mat)
+
+ tfaData = PriorTFAData()
+ InferelatorJL.processPriorFile!(tfaData, data, priorFile;
+ mergedTFsData = nothing, minTargets = 1)
+ InferelatorJL.calculateTFA!(tfaData, data; edgeSS = 0, zTarget = false)
+
+ grnData = GrnData()
+ # TFmRNA mode: self-regulatory edges must be Inf
+ InferelatorJL.preparePredictorMat!(grnData, data, tfaData; tfaOpt = "TFmRNA")
+ InferelatorJL.preparePenaltyMatrix!(data, grnData;
+ priorFilePenalties = [priorFile],
+ lambdaBias = [0.5],
+ tfaOpt = "TFmRNA")
+
+ # For each TF that is also a target gene, penalty[gene, TF] must be Inf
+ for tf in data.potRegsmRNA
+ targIdx = findfirst(==(tf), data.targGenes)
+ tfIdx = findfirst(==(tf), grnData.allPredictors)
+ if targIdx !== nothing && tfIdx !== nothing
+ @test grnData.penaltyMat[targIdx, tfIdx] == Inf
+ end
+ end
+end
+
+
+# -----------------------------------------------------------------------------
+@testset "Subsamples — shape and reproducibility" begin
+# -----------------------------------------------------------------------------
+ tmpdir = mktempdir()
+ data, genes, tfs, samples = make_test_data(tmpdir)
+
+ prior_mat = Float64[1 0; 1 1; 0 1; 1 1; 1 1]
+ priorFile = joinpath(tmpdir, "prior.tsv")
+ write_prior_tsv(priorFile, tfs, sort(genes), prior_mat)
+
+ tfaData = PriorTFAData()
+ InferelatorJL.processPriorFile!(tfaData, data, priorFile;
+ mergedTFsData = nothing, minTargets = 1)
+ InferelatorJL.calculateTFA!(tfaData, data; edgeSS = 0, zTarget = false)
+
+ grnData = GrnData()
+ InferelatorJL.preparePredictorMat!(grnData, data, tfaData; tfaOpt = "")
+
+ totSS = 10
+ ssFrac = 0.7
+ nSamples = length(data.cellLabels)
+ expectCols = floor(Int, ssFrac * nSamples)
+
+ Random.seed!(123)
+ InferelatorJL.constructSubsamples(data, grnData; totSS = totSS, subsampleFrac = ssFrac)
+ subsamps_a = copy(grnData.subsamps)
+
+ # Correct shape
+ @test size(subsamps_a) == (totSS, expectCols)
+
+ # All indices are valid sample indices
+ @test all(1 .<= subsamps_a .<= nSamples)
+
+ # Reproducible with same seed
+ Random.seed!(123)
+ InferelatorJL.constructSubsamples(data, grnData; totSS = totSS, subsampleFrac = ssFrac)
+ subsamps_b = copy(grnData.subsamps)
+ @test subsamps_a == subsamps_b
+
+ # Different seed → different subsamples (with overwhelming probability)
+ Random.seed!(999)
+ InferelatorJL.constructSubsamples(data, grnData; totSS = totSS, subsampleFrac = ssFrac)
+ @test grnData.subsamps != subsamps_a
+end
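The reproducibility idiom used here — reseeding the global RNG immediately before each call — works for any draw from the global stream. A standalone sketch (the `randperm`-based `draw` is a hypothetical stand-in, not the package's subsampler):

```julia
using Random

# Hypothetical subsample: k of n indices, without replacement.
draw(n, k) = randperm(n)[1:k]

Random.seed!(123); a = draw(10, 7)
Random.seed!(123); b = draw(10, 7)
@assert a == b                             # same seed ⇒ identical subsample
@assert allunique(a) && all(1 .<= a .<= 10)
```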
+
+
+# -----------------------------------------------------------------------------
+@testset "Edge cases — all target genes have zero variance" begin
+# -----------------------------------------------------------------------------
+ tmpdir = mktempdir()
+ genes = ["Akt1", "Bcl2"]
+ samples = ["S1", "S2", "S3"]
+ # All rows constant — every gene will fail the variance filter
+ mat = [1.0 1.0 1.0; 2.0 2.0 2.0]
+ exprFile = joinpath(tmpdir, "expr.txt")
+ targFile = joinpath(tmpdir, "targs.txt")
+ write_expr_tsv(exprFile, genes, samples, mat)
+ write_gene_list(targFile, genes)
+
+ data = GeneExpressionData()
+ InferelatorJL.loadExpressionData!(data, exprFile)
+ @test_throws ErrorException InferelatorJL.loadAndFilterTargetGenes!(
+ data, targFile; epsilon = 0.01)
+end
+
+
+# -----------------------------------------------------------------------------
+@testset "Edge cases — target gene list requests genes not in expression data" begin
+# -----------------------------------------------------------------------------
+ tmpdir = mktempdir()
+ genes = ["Akt1", "Bcl2"]
+ samples = ["S1", "S2", "S3"]
+ mat = [1.0 2.0 3.0; 4.0 5.0 6.0]
+ exprFile = joinpath(tmpdir, "expr.txt")
+ targFile = joinpath(tmpdir, "targs.txt")
+ write_expr_tsv(exprFile, genes, samples, mat)
+ write_gene_list(targFile, ["NotAGene", "AlsoNotAGene"])
+
+ data = GeneExpressionData()
+ InferelatorJL.loadExpressionData!(data, exprFile)
+ @test_throws ErrorException InferelatorJL.loadAndFilterTargetGenes!(
+ data, targFile; epsilon = 0.01)
+end
+
+
+# =============================================================================
+end # @testset "InferelatorJL"
+# =============================================================================