diff --git a/.Rhistory b/.Rhistory new file mode 100644 index 0000000..9d6d90b --- /dev/null +++ b/.Rhistory @@ -0,0 +1 @@ +# Developer Notes — InferelatorJL diff --git a/.gitignore b/.gitignore index e43b0f9..8f85e90 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,28 @@ .DS_Store + +# Manifest is machine/Julia-version specific — each environment generates its own. +# Only Project.toml is committed. Run `julia --project=. -e 'using Pkg; Pkg.instantiate()'` +# to regenerate on a new machine or Julia version. +Manifest.toml +Manifest.toml.bak + +# ========================= +# Editor +# ========================= +.vscode/ +.idea/ + +# ========================= +# Logs / temp +# ========================= +*.log + +# ========================= +# Data / outputs (IMPORTANT for your project) +# ========================= +results/ +output/ +*.tsv +*.csv +*.arrow + diff --git a/DEVELOPER.md b/DEVELOPER.md new file mode 100644 index 0000000..37a6bb7 --- /dev/null +++ b/DEVELOPER.md @@ -0,0 +1,266 @@ +# Developer Notes — InferelatorJL + +Personal reference for working on this codebase. + +--- + +## Starting Julia + +```bash +julia # from any terminal +``` + +Julia has four REPL modes. You switch between them with single keystrokes: + +| Mode | Prompt | Enter with | Exit with | +|---|---|---|---| +| Normal | `julia>` | default | — | +| Package (Pkg) | `(@v1.X) pkg>` | `]` | Backspace | +| Shell | `shell>` | `;` | Backspace | +| Help | `help?>` | `?` | Backspace | + +The `@v1.X` in the Pkg prompt shows your Julia version (e.g. `@v1.12` locally, `@v1.7` on the cluster). + +--- + +## One-time setup — local Mac (Julia 1.12) + +```julia +julia> ] +(@v1.12) pkg> dev /Users/owop7y/Desktop/InferelatorJL +(@v1.12) pkg> instantiate +``` + +Then rebuild PyCall so PyPlot works (required once per Julia version): + +```julia +(@v1.12) pkg> build PyCall +``` + +- `dev` registers the package by path — Julia reads directly from your folder, no copying. 
+- `instantiate` resolves and downloads all dependencies, generating a local `Manifest.toml`. +- `build PyCall` compiles the Python–Julia bridge for your current Python and Julia version. + +--- + +## One-time setup — cluster (Julia 1.7.3) + +The cluster does not use `dev` mode — you work directly inside the package folder. +Run this once after cloning or pulling the repo: + +```bash +# On the cluster, inside the InferelatorJL directory: +julia --project=. -e 'using Pkg; Pkg.instantiate()' +julia --project=. -e 'using Pkg; Pkg.build("PyCall")' +``` + +This generates a fresh `Manifest.toml` on the cluster, resolved specifically for Julia 1.7.3. +The cluster Manifest is gitignored — it never conflicts with your local one. + +--- + +## Every session + +```julia +using Revise # must come BEFORE the package +using InferelatorJL # loads the package +``` + +**Always load Revise first.** It monitors `src/` for changes and patches the running +session when you save a file — no restart needed. + +--- + +## Revise: what it can and cannot do + +| Change | Revise handles it? | +|---|---| +| Edit a function body | ✅ Live — takes effect on next call | +| Add a new function | ✅ Live | +| Add a new `include(...)` to `InferelatorJL.jl` | ✅ Live | +| Add or rename a field in a struct (`Types.jl`) | ❌ Must restart Julia | +| Change a struct's field type | ❌ Must restart Julia | +| Rename or delete a struct | ❌ Must restart Julia | + +Struct changes are the only thing that forces a restart. Everything else is live. + +--- + +## Running examples + +From the REPL after loading the package: +```julia +include("examples/interactive_pipeline.jl") +include("examples/utilityExamples.jl") +``` + +From the terminal (no REPL): +```bash +julia --project=. examples/run_pipeline.jl +``` + +--- + +## Running tests + +Run the test suite directly — works on all Julia versions: + +```bash +# From the terminal, inside the InferelatorJL directory: +julia --project=. 
test/runtests.jl +``` + +> **Why not `] test`?** +> In Julia ≥ 1.10, `Pkg.test()` creates a sandbox environment and calls +> `check_registered()` on all test dependencies. Stdlib packages like `Test` +> are not in the General registry, so this fails with +> `"expected package Test to be registered"`. +> Running the test file directly with `julia --project=.` bypasses the sandbox +> and works correctly on all Julia versions. + +--- + +## Working across machines and Julia versions + +`Manifest.toml` is **gitignored** — it is machine and Julia-version specific. +Only `Project.toml` is committed. Each machine generates its own Manifest. + +| Machine | What to do once | What gets committed | +|---|---|---| +| Local Mac (1.12) | `] instantiate` + `] build PyCall` | nothing (Manifest is gitignored) | +| Cluster (1.7.3) | `Pkg.instantiate()` + `Pkg.build("PyCall")` | nothing | +| Any new machine | same as above | nothing | + +**After `git pull` that changes `Project.toml`** (e.g. new dependency added), re-run: +```bash +julia --project=. -e 'using Pkg; Pkg.instantiate()' +``` + +**After upgrading Julia** to a new version, re-run both: +```bash +julia --project=. -e 'using Pkg; Pkg.instantiate()' +julia --project=. -e 'using Pkg; Pkg.build("PyCall")' +``` + +**If `Pkg.instantiate()` fails** with `"empty intersection"` or `"unsatisfiable requirements"`: +a specific package's compat bound in `Project.toml` is too narrow for the Julia version +being used. The error names the package. Widen its bound in `Project.toml` and re-run. +Example: `GLMNet = "0.4, 0.5, 0.6, 0.7"` → `"0.4, 0.5, 0.6, 0.7, 0.8"`. + +--- + +## Checking what changed / what's loaded + +```julia +# Which version is active? +using Pkg; Pkg.status("InferelatorJL") + +# Where is it loaded from? +pathof(InferelatorJL) + +# What does the package export? +names(InferelatorJL) + +# What fields does a struct have? 
+fieldnames(GeneExpressionData) +fieldnames(GrnData) + +# Inspect a loaded struct at runtime +data = GeneExpressionData() +propertynames(data) +``` + +--- + +## Switching between dev and release versions + +```julia +# Currently using dev (your local folder): +] status InferelatorJL # shows InferelatorJL [path] /Users/owop7y/Desktop/InferelatorJL + +# Switch to the released version (once it is published): +] free InferelatorJL # removes the dev pin +] add InferelatorJL # installs from registry + +# Switch back to dev: +] dev /Users/owop7y/Desktop/InferelatorJL +``` + +--- + +## Project environments (keeping work isolated) + +Every directory can have its own `Project.toml`. To work inside a specific project +environment (e.g., a collaborator's analysis folder): + +```julia +] activate /path/to/project # switch to that environment +] status # see what is installed there +] activate # return to your default environment (@v1.X) +``` + +When you `] dev .` from inside a project folder, the package is only registered +in that project, not globally. + +--- + +## Common errors and what they mean + +| Error | Cause | Fix | +|---|---|---| +| `UndefVarError: InferelatorJL not defined` | Package not loaded | `using InferelatorJL` | +| `UndefVarError: Revise not defined` | Revise not installed | `] add Revise` | +| `Cannot redefine struct` | Changed `Types.jl` | Restart Julia | +| `MethodError: no method matching ...` | Wrong argument types or order | Check function signature with `?functionname` | +| `KeyError` on a dict field | Field name wrong | `fieldnames(StructType)` to check | +| `PyCall not properly installed` | PyCall not built for this Julia version | `] build PyCall` | +| `expected package Test to be registered` | `] test` sandbox bug in Julia ≥ 1.10 | Use `julia --project=. 
test/runtests.jl` instead | +| `empty intersection between X@Y.Z and project compatibility` | Compat bound too narrow for installed version | Widen bound for that package in `Project.toml`, then `] instantiate` | +| `Warning: Package X does not have Y in its dependencies` | Missing dep in `Project.toml` | Add it with `] add Y` | + +--- + +## Package structure quick reference + +``` +src/ + InferelatorJL.jl ← module entry point, all includes and using statements + Types.jl ← ALL struct definitions (edit here for new fields) + API.jl ← public API (loadData, buildNetwork, etc.) + data/ ← data loading functions + prior/ ← TF merging + utils/ ← DataUtils, NetworkIO, PartialCorrelation + grn/ ← pipeline core (PrepareGRN, BuildGRN, AggregateNetworks, RefineTFA) + metrics/ ← PR/ROC evaluation and plotting + +examples/ + interactive_pipeline.jl ← step-by-step, public API (your main reference) + interactive_pipeline_dev.jl ← step-by-step, internal calls (your dev reference) + run_pipeline.jl ← function-wrapped, public API + run_pipeline_dev.jl ← function-wrapped, internal calls + utilityExamples.jl ← utility function demos, no real data needed + plotPR.jl ← PR curve evaluation + +test/ + runtests.jl ← unit tests (run with: julia --project=. test/runtests.jl) +``` + +--- + +## Adding a new exported function + +1. Write the function in the appropriate `src/` file. +2. Add `export myFunction` to the export block in `src/InferelatorJL.jl`. +3. Add a docstring above the function definition (Julia uses `"""..."""`). +4. If it is a utility function without real-data requirements, add an example + to `examples/utilityExamples.jl`. + +--- + +## Adding a new dependency + +1. `] add PackageName` — this updates both `Project.toml` and `Manifest.toml`. +2. Add `using PackageName` in the appropriate `src/` file. +3. Add a `[compat]` bound in `Project.toml` for the new package. +4. **Commit only `Project.toml`** — `Manifest.toml` is gitignored and should not be committed. 
+ Other machines will regenerate their own Manifest via `Pkg.instantiate()`. diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 0000000..c753c5b --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,968 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.12.5" +manifest_format = "2.0" +project_hash = "d561f5e18fb60a24178d6af859af5197e5eb087c" + +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "35ea197a51ce46fcd01c4a44befce0578a1aaeca" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "4.5.0" +weakdeps = ["SparseArrays", "StaticArrays"] + + [deps.Adapt.extensions] + AdaptSparseArraysExt = "SparseArrays" + AdaptStaticArraysExt = "StaticArrays" + +[[deps.AliasTables]] +deps = ["PtrArrays", "Random"] +git-tree-sha1 = "9876e1e164b144ca45e9e3198d0b689cadfed9ff" +uuid = "66dad0bd-aa9a-41b7-9441-69ab47430ed8" +version = "1.1.3" + +[[deps.ArgParse]] +deps = ["Logging", "TextWrap"] +git-tree-sha1 = "22cf435ac22956a7b45b0168abbc871176e7eecc" +uuid = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" +version = "1.2.0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.2" + +[[deps.Arrow]] +deps = ["ArrowTypes", "BitIntegers", "CodecLz4", "CodecZstd", "ConcurrentUtilities", "DataAPI", "Dates", "EnumX", "Mmap", "PooledArrays", "SentinelArrays", "StringViews", "Tables", "TimeZones", "TranscodingStreams", "UUIDs"] +git-tree-sha1 = "4a69a3eadc1f7da78d950d1ef270c3a62c1f7e01" +uuid = "69666777-d1a9-59fb-9406-91d4454c9d45" +version = "2.8.1" + +[[deps.ArrowTypes]] +deps = ["Sockets", "UUIDs"] +git-tree-sha1 = "404265cd8128a2515a81d5eae16de90fdef05101" +uuid = "31f734f8-188a-4ce0-8406-c8a06bd891cd" +version = "2.3.0" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +version = "1.11.0" + +[[deps.AxisAlgorithms]] +deps = ["LinearAlgebra", "Random", "SparseArrays", "WoodburyMatrices"] +git-tree-sha1 = "01b8ccb13d68535d73d2b0c23e39bd23155fb712" +uuid = 
"13072b0f-2c55-5437-9ae7-d433b7a33950" +version = "1.1.0" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +version = "1.11.0" + +[[deps.BitIntegers]] +deps = ["Random"] +git-tree-sha1 = "091d591a060e43df1dd35faab3ca284925c48e46" +uuid = "c3b6d118-76ef-56ca-8cc7-ebb389d030a1" +version = "0.3.7" + +[[deps.CSV]] +deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] +git-tree-sha1 = "8d8e0b0f350b8e1c91420b5e64e5de774c2f0f4d" +uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +version = "0.10.16" + +[[deps.CategoricalArrays]] +deps = ["Compat", "DataAPI", "Future", "Missings", "Printf", "Requires", "Statistics", "Unicode"] +git-tree-sha1 = "a6f644eb7bbc0171286f0f3ad1ffde8f04be7b83" +uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597" +version = "1.1.0" + + [deps.CategoricalArrays.extensions] + CategoricalArraysArrowExt = "Arrow" + CategoricalArraysJSONExt = "JSON" + CategoricalArraysRecipesBaseExt = "RecipesBase" + CategoricalArraysSentinelArraysExt = "SentinelArrays" + CategoricalArraysStatsBaseExt = "StatsBase" + CategoricalArraysStructTypesExt = "StructTypes" + + [deps.CategoricalArrays.weakdeps] + Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" + JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" + RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" + SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c" + StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" + StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "e4c6a16e77171a5f5e25e9646617ab1c276c5607" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.26.0" +weakdeps = ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt = "SparseArrays" + +[[deps.CodecLz4]] +deps = ["Lz4_jll", "TranscodingStreams"] +git-tree-sha1 = 
"d58afcd2833601636b48ee8cbeb2edcb086522c2" +uuid = "5ba52731-8f18-5e0d-9241-30f10d1ec561" +version = "0.4.6" + +[[deps.CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "962834c22b66e32aa10f7611c08c8ca4e20749a9" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.8" + +[[deps.CodecZstd]] +deps = ["TranscodingStreams", "Zstd_jll"] +git-tree-sha1 = "da54a6cd93c54950c15adf1d336cfd7d71f51a56" +uuid = "6b39b394-51ab-5f42-8807-6242bab2b4c2" +version = "0.8.7" + +[[deps.ColorTypes]] +deps = ["FixedPointNumbers", "Random"] +git-tree-sha1 = "67e11ee83a43eb71ddc950302c53bf33f0690dfe" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.12.1" +weakdeps = ["StyledStrings"] + + [deps.ColorTypes.extensions] + StyledStringsExt = "StyledStrings" + +[[deps.Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] +git-tree-sha1 = "37ea44092930b1811e666c3bc38065d7d87fcc74" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.13.1" + +[[deps.Combinatorics]] +git-tree-sha1 = "c761b00e7755700f9cdf5b02039939d1359330e1" +uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +version = "1.1.0" + +[[deps.Compat]] +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "9d8a54ce4b17aa5bdce0ea5c34bc5e7c340d16ad" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.18.1" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.3.0+1" + +[[deps.ConcurrentUtilities]] +deps = ["Serialization", "Sockets"] +git-tree-sha1 = "21d088c496ea22914fe80906eb5bce65755e5ec8" +uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" +version = "2.5.1" + +[[deps.Conda]] +deps = ["Downloads", "JSON", "VersionParsing"] +git-tree-sha1 = "8f06b0cfa4c514c7b9546756dbae91fcfbc92dc9" +uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d" +version = "1.10.3" + +[[deps.Crayons]] +git-tree-sha1 = 
"249fe38abf76d48563e2f4556bebd215aa317e15" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.1.1" + +[[deps.DataAPI]] +git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.16.0" + +[[deps.DataFrames]] +deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "d8928e9169ff76c6281f39a659f9bca3a573f24c" +uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +version = "1.8.1" + +[[deps.DataStructures]] +deps = ["OrderedCollections"] +git-tree-sha1 = "e86f4a2805f7f19bec5129bc9150c38208e5dc23" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.19.4" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +version = "1.11.0" + +[[deps.DelimitedFiles]] +deps = ["Mmap"] +git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +version = "1.9.1" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +version = "1.11.0" + +[[deps.Distributions]] +deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] +git-tree-sha1 = "fbcc7610f6d8348428f722ecbe0e6cfe22e672c6" +uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" +version = "0.25.123" + + [deps.Distributions.extensions] + DistributionsChainRulesCoreExt = "ChainRulesCore" + DistributionsDensityInterfaceExt = "DensityInterface" + DistributionsTestExt = "Test" + + 
[deps.Distributions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" + Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.DocStringExtensions]] +git-tree-sha1 = "7442a5dfe1ebb773c29cc2962a8980f47221d76c" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.5" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.7.0" + +[[deps.EnumX]] +git-tree-sha1 = "c49898e8438c828577f04b92fc9368c388ac783c" +uuid = "4e289a0a-7415-4d19-859d-a7e5c4648b56" +version = "1.0.7" + +[[deps.ExprTools]] +git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.10" + +[[deps.FileIO]] +deps = ["Pkg", "Requires", "UUIDs"] +git-tree-sha1 = "6522cfb3b8fe97bec632252263057996cbd3de20" +uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +version = "1.18.0" + + [deps.FileIO.extensions] + HTTPExt = "HTTP" + + [deps.FileIO.weakdeps] + HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" + +[[deps.FilePathsBase]] +deps = ["Compat", "Dates"] +git-tree-sha1 = "3bab2c5aa25e7840a4b065805c0cdfc01f3068d2" +uuid = "48062228-2e41-5def-b9a4-89aafe57970f" +version = "0.9.24" +weakdeps = ["Mmap", "Test"] + + [deps.FilePathsBase.extensions] + FilePathsBaseMmapExt = "Mmap" + FilePathsBaseTestExt = "Test" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" +version = "1.11.0" + +[[deps.FillArrays]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "2f979084d1e13948a3352cf64a25df6bd3b4dca3" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "1.16.0" +weakdeps = ["PDMats", "SparseArrays", "StaticArrays", "Statistics"] + + [deps.FillArrays.extensions] + FillArraysPDMatsExt = "PDMats" + FillArraysSparseArraysExt = "SparseArrays" + FillArraysStaticArraysExt = "StaticArrays" + FillArraysStatisticsExt = "Statistics" + +[[deps.FixedPointNumbers]] +deps = 
["Statistics"] +git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.8.5" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" +version = "1.11.0" + +[[deps.GLM]] +deps = ["Distributions", "LinearAlgebra", "Printf", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "StatsModels"] +git-tree-sha1 = "3bcb30438ee1655e3b9c42d97544de7addc9c589" +uuid = "38e38edf-8417-5370-95a0-9cbb8c7f171a" +version = "1.9.3" + +[[deps.GLMNet]] +deps = ["DataFrames", "Distributed", "Distributions", "Printf", "Random", "SparseArrays", "StatsBase", "glmnet_jll"] +git-tree-sha1 = "b873c384d3490304c18224b1d5554cdebaafb60b" +uuid = "8d5ece8b-de18-5317-b113-243142960cc6" +version = "0.7.4" + +[[deps.HashArrayMappedTries]] +git-tree-sha1 = "2eaa69a7cab70a52b9687c8bf950a5a93ec895ae" +uuid = "076d061b-32b6-4027-95e0-9a2c6f6d7e74" +version = "0.2.0" + +[[deps.HypergeometricFunctions]] +deps = ["LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] +git-tree-sha1 = "68c173f4f449de5b438ee67ed0c9c748dc31a2ec" +uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" +version = "0.3.28" + +[[deps.InferelatorJL]] +deps = ["ArgParse", "Arrow", "CSV", "CategoricalArrays", "Colors", "DataFrames", "Dates", "DelimitedFiles", "Distributions", "FileIO", "GLM", "GLMNet", "InlineStrings", "Interpolations", "JLD2", "LinearAlgebra", "Measures", "NamedArrays", "OrderedCollections", "Printf", "ProgressBars", "PyPlot", "Random", "SparseArrays", "Statistics", "StatsBase", "TickTock"] +path = "." 
+uuid = "436bd8d6-fc45-48f7-bc1f-d4dd5aa384ad" +version = "0.1.0" + +[[deps.InlineStrings]] +git-tree-sha1 = "8f3d257792a522b4601c24a577954b0a8cd7334d" +uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" +version = "1.4.5" +weakdeps = ["ArrowTypes", "Parsers"] + + [deps.InlineStrings.extensions] + ArrowTypesExt = "ArrowTypes" + ParsersExt = "Parsers" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +version = "1.11.0" + +[[deps.Interpolations]] +deps = ["Adapt", "AxisAlgorithms", "ChainRulesCore", "LinearAlgebra", "OffsetArrays", "Random", "Ratios", "Requires", "SharedArrays", "SparseArrays", "StaticArrays", "WoodburyMatrices"] +git-tree-sha1 = "88a101217d7cb38a7b481ccd50d21876e1d1b0e0" +uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" +version = "0.15.1" + + [deps.Interpolations.extensions] + InterpolationsUnitfulExt = "Unitful" + + [deps.Interpolations.weakdeps] + Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" + +[[deps.InvertedIndices]] +git-tree-sha1 = "6da3c4316095de0f5ee2ebd875df8721e7e0bdbe" +uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +version = "1.3.1" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "b2d91fe939cae05960e760110b328288867b5758" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.6" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLD2]] +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "ScopedValues", "TranscodingStreams"] +git-tree-sha1 = "d97791feefda45729613fafeccc4fbef3f539151" +uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +version = "0.5.15" + + [deps.JLD2.extensions] + UnPackExt = "UnPack" + + [deps.JLD2.weakdeps] + UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "0533e564aae234aff59ab625543145446d8b6ec2" +uuid = 
"692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.7.1" + +[[deps.JSON]] +deps = ["Dates", "Logging", "Parsers", "PrecompileTools", "StructUtils", "UUIDs", "Unicode"] +git-tree-sha1 = "b3ad4a0255688dcb895a52fafbaae3023b588a90" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "1.4.0" +weakdeps = ["ArrowTypes"] + + [deps.JSON.extensions] + JSONArrowExt = ["ArrowTypes"] + +[[deps.JuliaSyntaxHighlighting]] +deps = ["StyledStrings"] +uuid = "ac6e5ff7-fb65-4e79-a425-ec3bc9c03011" +version = "1.12.0" + +[[deps.LaTeXStrings]] +git-tree-sha1 = "dda21b8cbd6a6c40d9d02a73230f9d70fed6918c" +uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" +version = "1.4.0" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "OpenSSL_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "8.15.0+0" + +[[deps.LibGit2]] +deps = ["LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +version = "1.11.0" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "OpenSSL_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.9.0+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "OpenSSL_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.3+1" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" +version = "1.11.0" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +version = "1.12.0" + +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "13ca9e2586b89836fd20cccf56e57e2b9ae7f38f" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.29" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + 
LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +version = "1.11.0" + +[[deps.Lz4_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "191686b1ac1ea9c89fc52e996ad15d1d241d1e33" +uuid = "5ced341a-0733-55b8-9ab6-a4889d929147" +version = "1.10.1+0" + +[[deps.MacroTools]] +git-tree-sha1 = "1e0228a030642014fe5cfe68c2c0a818f9e3f522" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.16" + +[[deps.Markdown]] +deps = ["Base64", "JuliaSyntaxHighlighting", "StyledStrings"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +version = "1.11.0" + +[[deps.Measures]] +git-tree-sha1 = "b513cedd20d9c914783d8ad83d08120702bf2c77" +uuid = "442fdcdd-2543-5da2-b0f3-8c86c306513e" +version = "0.3.3" + +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.2.0" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" +version = "1.11.0" + +[[deps.Mocking]] +deps = ["Compat", "ExprTools"] +git-tree-sha1 = "2c140d60d7cb82badf06d8783800d0bcd1a7daa2" +uuid = "78c3b35d-d492-501b-9361-3d52fe80e533" +version = "0.8.1" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2025.11.4" + +[[deps.NamedArrays]] +deps = ["Combinatorics", "DelimitedFiles", "InvertedIndices", "LinearAlgebra", "OrderedCollections", "Random", "Requires", "SparseArrays", "Statistics"] +git-tree-sha1 = "33d258318d9e049d26c02ca31b4843b2c851c0b0" +uuid = "86f7a689-2022-50b4-a561-43c23ac3c673" +version = "0.10.5" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.3.0" + 
+[[deps.OffsetArrays]] +git-tree-sha1 = "117432e406b5c023f665fa73dc26e79ec3630151" +uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +version = "1.17.0" +weakdeps = ["Adapt"] + + [deps.OffsetArrays.extensions] + OffsetArraysAdaptExt = "Adapt" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.29+0" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.7+0" + +[[deps.OpenSSL_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "3.5.4+0" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl"] +git-tree-sha1 = "1346c9208249809840c91b26703912dff463d335" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.6+0" + +[[deps.OrderedCollections]] +git-tree-sha1 = "05868e21324cede2207c6f0f466b4bfef6d5e7ee" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.8.1" + +[[deps.PDMats]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "e4cff168707d441cd6bf3ff7e4832bdf34278e4a" +uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" +version = "0.11.37" +weakdeps = ["StatsBase"] + + [deps.PDMats.extensions] + StatsBaseExt = "StatsBase" + +[[deps.Parsers]] +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "7d2f8f21da5db6a806faf7b9b292296da42b2810" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.8.3" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "Random", "SHA", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.12.1" +weakdeps = ["REPL"] + + [deps.Pkg.extensions] + REPLExt = "REPL" + +[[deps.PooledArrays]] +deps = ["DataAPI", "Future"] +git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" +uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" 
+version = "1.4.3" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "07a921781cab75691315adc645096ed5e370cb77" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.3.3" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "8b770b60760d4451834fe79dd483e318eee709c4" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.5.2" + +[[deps.PrettyTables]] +deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "REPL", "Reexport", "StringManipulation", "Tables"] +git-tree-sha1 = "624de6279ab7d94fc9f672f0068107eb6619732c" +uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +version = "3.3.2" + + [deps.PrettyTables.extensions] + PrettyTablesTypstryExt = "Typstry" + + [deps.PrettyTables.weakdeps] + Typstry = "f0ed7684-a786-439e-b1e3-3b82803b501e" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" +version = "1.11.0" + +[[deps.ProgressBars]] +deps = ["Printf"] +git-tree-sha1 = "b437cdb0385ed38312d91d9c00c20f3798b30256" +uuid = "49802e3a-d2f1-5c88-81d8-b72133a6f568" +version = "1.5.1" + +[[deps.PtrArrays]] +git-tree-sha1 = "4fbbafbc6251b883f4d2705356f3641f3652a7fe" +uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" +version = "1.4.0" + +[[deps.PyCall]] +deps = ["Conda", "Dates", "Libdl", "LinearAlgebra", "MacroTools", "Serialization", "VersionParsing"] +git-tree-sha1 = "9816a3826b0ebf49ab4926e2b18842ad8b5c8f04" +uuid = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" +version = "1.96.4" + +[[deps.PyPlot]] +deps = ["Colors", "LaTeXStrings", "PyCall", "Sockets", "Test", "VersionParsing"] +git-tree-sha1 = "d2c2b8627bbada1ba00af2951946fb8ce6012c05" +uuid = "d330b81b-6aea-500a-939a-2ce795aea3ee" +version = "2.11.6" + +[[deps.QuadGK]] +deps = ["DataStructures", "LinearAlgebra"] +git-tree-sha1 = "9da16da70037ba9d701192e27befedefb91ec284" +uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +version = "2.11.2" + + [deps.QuadGK.extensions] + QuadGKEnzymeExt = "Enzyme" + + [deps.QuadGK.weakdeps] + Enzyme = 
"7da242da-08ed-463a-9acd-ee780be4f1d9" + +[[deps.REPL]] +deps = ["InteractiveUtils", "JuliaSyntaxHighlighting", "Markdown", "Sockets", "StyledStrings", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" +version = "1.11.0" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +version = "1.11.0" + +[[deps.Ratios]] +deps = ["Requires"] +git-tree-sha1 = "1342a47bf3260ee108163042310d26f2be5ec90b" +uuid = "c84ed2f1-dad5-54f0-aa8e-dbefe2724439" +version = "0.4.5" +weakdeps = ["FixedPointNumbers"] + + [deps.Ratios.extensions] + RatiosFixedPointNumbersExt = "FixedPointNumbers" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "62389eeff14780bfe55195b7204c0d8738436d64" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.1" + +[[deps.Rmath]] +deps = ["Random", "Rmath_jll"] +git-tree-sha1 = "5b3d50eb374cea306873b371d3f8d3915a018f0b" +uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" +version = "0.9.0" + +[[deps.Rmath_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "58cdd8fb2201a6267e1db87ff148dd6c1dbd8ad8" +uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" +version = "0.5.1+0" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.ScopedValues]] +deps = ["HashArrayMappedTries", "Logging"] +git-tree-sha1 = "ac4b837d89a58c848e85e698e2a2514e9d59d8f6" +uuid = "7e506255-f358-4e82-b7e4-beb19740aa63" +version = "1.6.0" + +[[deps.Scratch]] +deps = ["Dates"] +git-tree-sha1 = "9b81b8393e50b7d4e6d0a9f14e192294d3b7c109" +uuid = "6c6a2e73-6563-6170-7368-637461726353" +version = "1.3.0" + +[[deps.SentinelArrays]] +deps = ["Dates", "Random"] +git-tree-sha1 = "ebe7e59b37c400f694f52b58c93d26201387da70" +uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" +version = "1.4.9" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +version 
= "1.11.0" + +[[deps.SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" +version = "1.11.0" + +[[deps.ShiftedArrays]] +git-tree-sha1 = "503688b59397b3307443af35cd953a13e8005c16" +uuid = "1277b4bf-5013-50f5-be3d-901d8477a67a" +version = "2.0.0" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" +version = "1.11.0" + +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "64d974c2e6fdf07f8155b5b2ca2ffa9069b608d9" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.2.2" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.12.0" + +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "2700b235561b0335d5bef7097a111dc513b8655e" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.7.2" +weakdeps = ["ChainRulesCore"] + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] +git-tree-sha1 = "246a8bb2e6667f832eea063c3a56aef96429a3db" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.9.18" +weakdeps = ["ChainRulesCore", "Statistics"] + + [deps.StaticArrays.extensions] + StaticArraysChainRulesCoreExt = "ChainRulesCore" + StaticArraysStatisticsExt = "Statistics" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "6ab403037779dae8c514bad259f32a447262455a" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.4" + +[[deps.Statistics]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "ae3bb1eb3bba077cd276bc5cfc337cc65c3075c0" +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.11.1" +weakdeps = ["SparseArrays"] + + [deps.Statistics.extensions] + SparseArraysExt = ["SparseArrays"] + +[[deps.StatsAPI]] +deps = 
["LinearAlgebra"] +git-tree-sha1 = "178ed29fd5b2a2cfc3bd31c13375ae925623ff36" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.8.0" + +[[deps.StatsBase]] +deps = ["AliasTables", "DataAPI", "DataStructures", "IrrationalConstants", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "aceda6f4e598d331548e04cc6b2124a6148138e3" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.34.10" + +[[deps.StatsFuns]] +deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] +git-tree-sha1 = "91f091a8716a6bb38417a6e6f274602a19aaa685" +uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +version = "1.5.2" + + [deps.StatsFuns.extensions] + StatsFunsChainRulesCoreExt = "ChainRulesCore" + StatsFunsInverseFunctionsExt = "InverseFunctions" + + [deps.StatsFuns.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.StatsModels]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Printf", "REPL", "ShiftedArrays", "SparseArrays", "StatsAPI", "StatsBase", "StatsFuns", "Tables"] +git-tree-sha1 = "08786db4a1346d17d0a8d952d2e66fd00fa18192" +uuid = "3eaba693-59b7-5ba5-a881-562e759f1c8d" +version = "0.7.9" + +[[deps.StringManipulation]] +deps = ["PrecompileTools"] +git-tree-sha1 = "d05693d339e37d6ab134c5ab53c29fce5ee5d7d5" +uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" +version = "0.4.4" + +[[deps.StringViews]] +git-tree-sha1 = "f2dcb92855b31ad92fe8f079d4f75ac57c93e4b8" +uuid = "354b36f9-a18e-4713-926e-db85100087ba" +version = "1.3.7" + +[[deps.StructUtils]] +deps = ["Dates", "UUIDs"] +git-tree-sha1 = "fa95b3b097bcef5845c142ea2e085f1b2591e92c" +uuid = "ec057cc2-7a8d-4b58-b3b3-92acb9f63b42" +version = "2.7.1" + + [deps.StructUtils.extensions] + StructUtilsMeasurementsExt = ["Measurements"] + StructUtilsStaticArraysCoreExt = 
["StaticArraysCore"] + StructUtilsTablesExt = ["Tables"] + + [deps.StructUtils.weakdeps] + Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" + StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" + Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" + +[[deps.StyledStrings]] +uuid = "f489334b-da3d-4c2e-b8f0-e476e12c162b" +version = "1.11.0" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.8.3+2" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TZJData]] +deps = ["Artifacts"] +git-tree-sha1 = "72df96b3a595b7aab1e101eb07d2a435963a97e2" +uuid = "dc5dba14-91b3-4cab-a142-028a31da12f7" +version = "1.5.0+2025b" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "f2c1efbc8f3a609aadf318094f8fc5204bdaf344" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.12.1" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +version = "1.11.0" + +[[deps.TextWrap]] +git-tree-sha1 = "43044b737fa70bc12f6105061d3da38f881a3e3c" +uuid = "b718987f-49a8-5099-9789-dcd902bef87d" +version = "1.0.2" + +[[deps.TickTock]] +deps = ["Dates"] +git-tree-sha1 = "385ff4318d1159050cb129f908804ff95b830de0" +uuid = "9ff05d80-102d-5586-aa04-3a8bd1a90d20" +version = "1.3.0" + +[[deps.TimeZones]] +deps = ["Artifacts", "Dates", 
"Downloads", "InlineStrings", "Mocking", "Printf", "Scratch", "TZJData", "Unicode", "p7zip_jll"] +git-tree-sha1 = "d422301b2a1e294e3e4214061e44f338cafe18a2" +uuid = "f269a46b-ccf7-5d73-abea-4c690281aa53" +version = "1.22.2" + + [deps.TimeZones.extensions] + TimeZonesRecipesBaseExt = "RecipesBase" + + [deps.TimeZones.weakdeps] + RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" + +[[deps.TranscodingStreams]] +git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.11.3" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +version = "1.11.0" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" +version = "1.11.0" + +[[deps.VersionParsing]] +git-tree-sha1 = "58d6e80b4ee071f5efd07fda82cb9fbe17200868" +uuid = "81def892-9a0e-5fdd-b105-ffc91e053289" +version = "1.3.0" + +[[deps.WeakRefStrings]] +deps = ["DataAPI", "InlineStrings", "Parsers"] +git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" +uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" +version = "1.4.2" + +[[deps.WoodburyMatrices]] +deps = ["LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "248a7031b3da79a127f14e5dc5f417e26f9f6db7" +uuid = "efce3f68-66dc-5838-9240-27a6d6f5f9b6" +version = "1.1.0" + +[[deps.WorkerUtilities]] +git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7" +uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60" +version = "1.6.1" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.3.1+2" + +[[deps.Zstd_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "446b23e73536f84e8037f5dce465e92275f6a308" +uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" +version = "1.5.7+1" + +[[deps.glmnet_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "31adae3b983b579a1fbd7cfd43a4bc0d224c2f5a" +uuid = "78c6b45d-5eaf-5d68-bcfb-a5a2cb06c27f" +version = "2.0.13+0" + 
+[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.15.0+0" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.64.0+1" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.7.0+0" diff --git a/UtilsRprog/DownSamplingObject.R b/Processing/DownSamplingObject.R similarity index 100% rename from UtilsRprog/DownSamplingObject.R rename to Processing/DownSamplingObject.R diff --git a/UtilsRprog/saveNormCountsArrowFIle.R b/Processing/saveNormCountsArrowFIle.R similarity index 100% rename from UtilsRprog/saveNormCountsArrowFIle.R rename to Processing/saveNormCountsArrowFIle.R diff --git a/Project.toml b/Project.toml new file mode 100755 index 0000000..2d32763 --- /dev/null +++ b/Project.toml @@ -0,0 +1,61 @@ +name = "InferelatorJL" +uuid = "436bd8d6-fc45-48f7-bc1f-d4dd5aa384ad" +version = "0.1.0" + +[deps] +ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" +Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" +CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" +Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" +DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a" +GLMNet = "8d5ece8b-de18-5317-b113-243142960cc6" +InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" +Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Measures = "442fdcdd-2543-5da2-b0f3-8c86c306513e" +NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673" +OrderedCollections = 
"bac558e1-5e72-5ebc-8fee-abe8a469f55d" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568" +PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +TickTock = "9ff05d80-102d-5586-aa04-3a8bd1a90d20" + +[compat] +julia = "1.7" +ArgParse = "1" +Arrow = "2" +CSV = "0.10" +CategoricalArrays = "0.10, 1" +Colors = "0.12, 0.13" +DataFrames = "1" +Distributions = "0.25" +FileIO = "1" +GLM = "1" +GLMNet = "0.4, 0.5, 0.6, 0.7" +InlineStrings = "1" +Interpolations = "0.13, 0.14, 0.15" +JLD2 = "0.4, 0.5" +Measures = "0.3" +NamedArrays = "0.9, 0.10" +OrderedCollections = "1" +ProgressBars = "1" +PyPlot = "2" +StatsBase = "0.33, 0.34" +TickTock = "1" + +[extras] +Test = "8dfed614-e22c-358a-adea-af005a22d00b" + +[targets] +test = ["Test"] diff --git a/README.md b/README.md index 6f75c6e..1728d3b 100755 --- a/README.md +++ b/README.md @@ -1,14 +1,218 @@ -# Inferelator_Julia +# InferelatorJL -This repository contains a workflow for inference of transcriptional regulatory networks (TRNs) from gene expression data and prior information, as described in: +A Julia package for inference of transcriptional regulatory networks (GRNs) from +gene expression data and prior chromatin accessibility or binding information, +using the **mLASSO-StARS** algorithm. -Miraldi et al., Leveraging chromatin accessibility for transcriptional regulatory network inference in T Helper 17 Cells. +![Pipeline diagram](pipeline_diagram.png) -The main workflow is found here: \ -[workflow.jl](https://github.com/MiraldiLab/Inferelator_Julia/blob/main/Testing/workflow.jl) \ -This workflow can benefit from parallelization, with significant speedups from using multiple cores. 
The workflow can be run on the BMI cluster with n cores using the bat file: \ -[workflow.bat](https://github.com/MiraldiLab/Inferelator_Julia/blob/main/Testing/workflow.bat) \ +--- +## Overview +InferelatorJL infers genome-scale gene regulatory networks (GRNs) by combining +gene expression data with a prior network (e.g., from ATAC-seq or ChIP-seq) to +identify transcription factor (TF) → target gene regulatory relationships. +The method uses **multi-scale LASSO with Stability Approach to Regularization +Selection (mLASSO-StARS)**, a penalized regression framework in which the prior +network biases edge selection and stability selection across bootstrap subsamples +determines a data-driven sparsity level. -Screen Shot 2023-03-14 at 1 12 10 PM +**Key features:** + +- Supports bulk RNA-seq, pseudobulk, and single-cell expression data +- Estimates **TF activity (TFA)** via least squares as an alternative to raw TF mRNA +- Builds separate networks for TFA and TF mRNA predictors, then combines them +- Prior-weighted LASSO penalties reduce false positives for prior-supported edges +- Network-level or gene-level instability thresholding +- Built-in PR / ROC evaluation against gold-standard interaction sets +- Fully multithreaded subsampling loop + +--- + +## Installation + +**Requirements:** Julia ≥ 1.7, Python (for PyPlot/matplotlib) + +```julia +# From the Julia REPL: +] add https://github.com/miraldilab/InferelatorJL.jl + +# Install all dependencies: +] instantiate +``` + +**For development / local installation:** +```julia +] dev /path/to/InferelatorJL +] instantiate +``` + +--- + +## Quick start + +```julia +using InferelatorJL + +# Step 1 — load expression data +data = loadData(exprFile, targFile, regFile) + +# Steps 2–3 — merge degenerate TFs, build prior, estimate TFA +priorData, mergedTFs = loadPrior(data, priorFile) +estimateTFA(priorData, data; outputDir = dirOut) + +# Step 4 — infer GRN (TFA predictors) +buildNetwork(data, priorData; + tfaMode = true, + 
outputDir = joinpath(dirOut, "TFA")) + +# Step 5 — aggregate TFA + mRNA networks +aggregateNetworks([tfaEdges, mrnaEdges]; + method = :max, + outputDir = joinpath(dirOut, "Combined")) + +# Step 6 — refine TFA using the consensus network as a new prior +refineTFA(combinedNetFile, data, mergedTFs; outputDir = dirOut) +``` + +See [`examples/interactive_pipeline.jl`](examples/interactive_pipeline.jl) for a +fully annotated step-by-step example. + +--- + +## Pipeline + +The full pipeline has six steps. Each step is a single function call. + +| Step | Function | Description | +|---|---|---| +| 1 | `loadData` | Load expression matrix; filter to target genes and regulators | +| 2–3 | `loadPrior` + `estimateTFA` | Process prior, merge degenerate TFs, estimate TF activity via least squares | +| 4 | `buildNetwork` | Run mLASSO-StARS for one predictor mode (TFA or TF mRNA) | +| 5 | `aggregateNetworks` | Combine TFA and mRNA edge lists into a consensus network | +| 6 | `refineTFA` | Re-estimate TFA using the consensus network as a data-driven prior | +| — | `evaluateNetwork` | Evaluate any network against gold standards (PR/ROC curves) | + +Steps 4–6 are typically run twice (TFA mode and TF mRNA mode) and the results +aggregated in Step 5. 
+ +--- + +## Input files + +| File | Format | Description | +|---|---|---| +| Expression matrix | Tab-delimited TSV or Apache Arrow | Genes × samples; genes in rows, samples in columns; first column is gene names | +| Target gene list | Plain text, one gene per line | Genes to model as regression targets | +| Regulator list | Plain text, one TF per line | Candidate transcription factors | +| Prior network | Sparse TSV (TF × gene, empty first header) | Prior regulatory connections; values are edge weights (binary or continuous) | +| Prior penalties | Same format as prior network | Prior(s) used to set per-edge LASSO penalties; can be the same as prior network | +| TFA gene list | Plain text, one gene per line | Optional: restrict TFA estimation to a gene subset | + +--- + +## Output files + +All outputs are written under the directory specified by `outputDir`. + +| File | Description | +|---|---| +| `edges.tsv` | Full ranked edge table: TF, gene, signed quantile, stability, partial correlation, inPrior flag | +| `edges_subset.tsv` | Top edges after applying the `meanEdgesPerGene` cap | +| `instability_*.jld2` | Raw instability arrays across the λ grid (for diagnostics) | +| `combined_.tsv` | Aggregated network (long format) | +| `combined__sp.tsv` | Aggregated network (sparse prior format, for downstream use) | + +--- + +## Key parameters + +| Parameter | Default | Description | +|---|---|---| +| `totSS` | 80 | Total bootstrap subsamples for instability estimation | +| `subsampleFrac` | 0.63 | Fraction of samples drawn per subsample | +| `targetInstability` | 0.05 | Instability threshold for λ selection | +| `lambdaBias` | `[0.5]` | Penalty reduction factor for prior-supported edges (0 = no prior, 1 = uniform) | +| `meanEdgesPerGene` | 20 | Maximum retained edges per target gene | +| `instabilityLevel` | `"Network"` | `"Network"`: single λ for all genes; `"Gene"`: per-gene λ | +| `zScoreTFA` | `true` | Z-score expression before TFA estimation | +| `zScoreLASSO` | 
`true` | Z-score expression before LASSO regression | +| `method` | `:max` | Network aggregation rule: `:max`, `:mean`, or `:min` stability | + +--- + +## API reference + +### Data loading +```julia +data = loadData(exprFile, targFile, regFile; tfaGeneFile="", epsilon=0.01) +priorData, mergedTFs = loadPrior(data, priorFile; minTargets=3) +``` + +### Core pipeline +```julia +estimateTFA(priorData, data; edgeSS=0, zScoreTFA=true, outputDir=".") +buildNetwork(data, priorData; tfaMode=true, totSS=80, lambdaBias=[0.5], ...) +aggregateNetworks(netFiles; method=:max, meanEdgesPerGene=20, outputDir=".") +refineTFA(combinedNetFile, data, mergedTFs; zScoreTFA=true, outputDir=".") +``` + +### Evaluation +```julia +evaluateNetwork(gsFile, netFile; gsRegsFile="", targGeneFile="", outputDir=".") +``` + +### Utilities +```julia +convertToLong(df) # wide prior matrix → long format +convertToWide(df; indices=(2,1,3)) # long → wide +frobeniusNormalize(M, :column) # normalize matrix columns or rows +binarizeNumeric!(df) # continuous prior → binary +mergeDFs(dfs, :Gene, "sum") # merge prior DataFrames +completeDF(df, :Gene, allGenes, allTFs) # align to full gene/TF universe +writeTSVWithEmptyFirstHeader(df, path) # write sparse prior format +check_column_norms(M) # verify unit column norms +writeNetworkTable!(buildGrn; outputDir=".") # write edges.tsv from BuildGrn struct +saveData(data, tfaData, grnData, buildGrn, dir, "checkpoint.jld2") +``` + +--- + +## Examples + +| File | Description | +|---|---| +| [`interactive_pipeline.jl`](examples/interactive_pipeline.jl) | Full pipeline, step-by-step in the REPL, public API | +| [`run_pipeline.jl`](examples/run_pipeline.jl) | Full pipeline wrapped in `runInferelator()` for batch use | +| [`utilityExamples.jl`](examples/utilityExamples.jl) | All utility functions demonstrated with synthetic data; no input files needed | +| [`plotPR.jl`](examples/plotPR.jl) | Evaluate networks against gold standards and generate PR curve plots | + +--- + +## 
Data structures + +| Struct | Populated by | Key fields | +|---|---|---| +| `GeneExpressionData` | `loadData` | `expressionMat`, `geneNames`, `sampleNames`, `tfNames` | +| `PriorTFAData` | `loadPrior`, `estimateTFA` | `priorMat`, `medTfas`, `tfNames`, `targetGenes` | +| `mergedTFsResult` | `loadPrior` | `mergedTFs`, `tfNames`, `mergeTfLocVec` | +| `GrnData` | `buildNetwork` internals | `predictorMat`, `penaltyMat`, `stabilityMat`, `subsampleMat` | +| `BuildGrn` | `buildNetwork` | `regs`, `targs`, `signedQuantile`, `rankings`, `networkMat` | + +--- + +## Citation + +If you use InferelatorJL in your work, please cite: + +> Miraldi ER, Pokrovskii M, Watters A, et al. +> **Leveraging chromatin accessibility for transcriptional regulatory network inference in T Helper 17 Cells.** +> *PLOS Computational Biology*, 2019. +> https://doi.org/10.1371/journal.pcbi.1006979 + +--- + +## License + +See [LICENSE](LICENSE) for details. diff --git a/archive/GRN.jl b/archive/GRN.jl new file mode 100755 index 0000000..82ccbf8 --- /dev/null +++ b/archive/GRN.jl @@ -0,0 +1,116 @@ +module GRN + + include("utilsGRN.jl") + + using ..Data + using ..DataUtils + using ..PriorTFA + using ..MergeDegenerate + using ..NetworkIO + + using GLMNet + using Random, Statistics, StatsBase + using Distributions + using DataFrames, CategoricalArrays, SparseArrays, CSV, DelimitedFiles + using LinearAlgebra + using TickTock + using JLD2 + using ProgressBars + using Printf, Dates + using PyPlot + using Statistics + using CSV + using NamedArrays + using ArgParse + + mutable struct GrnData + predictorMat::Matrix{Float64} + penaltyMat::Matrix{Float64} + allPredictors::Vector{String} + subsamps::Matrix{Int64} + responseMat::Matrix{Float64} + maxLambdaNet::Float64 + minLambdaNet::Float64 + minLambdas::Matrix{Float64} + maxLambdas::Matrix{Float64} + netInstabilitiesUb::Vector{Float64} + netInstabilitiesLb::Vector{Float64} + instabilitiesUb::Matrix{Float64} + instabilitiesLb::Matrix{Float64} + 
netInstabilities::Vector{Float64} + geneInstabilities::Matrix{Float64} + lambdaRange::Vector{Float64} + lambdaRangeGene::Vector{Vector{Float64}} + stabilityMat::Array{Float64} + priorMatProcessed::Matrix{Float64} + betas::Array{Float64,3} + function GrnData() + return new( + Matrix{Float64}(undef, 0, 0), # predictorMat + Matrix{Float64}(undef, 0, 0), # penaltyMat + [], # allPredictors + Matrix{Int64}(undef, 0, 0), # subsamps + Matrix{Int64}(undef, 0, 0), # responseMat + 0.0, # maxLambdasNet + 0.0, # minLambdasNet + Matrix{Int64}(undef, 0, 0), # minLambdas + Matrix{Int64}(undef, 0, 0), # maxLambdas + [], # netInstabilitiesUb + [], # netInstabilitiesLb + Matrix{Int64}(undef, 0, 0), # instabilitiesUb + Matrix{Int64}(undef, 0, 0), # instabilitiesLb + [], # netInstabilities + Matrix{Int64}(undef, 0, 0), # geneInstabilities + [], # lambdaRange + Vector{Vector{Float64}}(undef, 0), # lambdaRangesGene + Matrix{Int64}(undef, 0, 0), # stabilityMat + Matrix{Float64}(undef, 0, 0), # priorMatProcessed + Array{Float64,3}(undef, 0, 0, 0) # betas + ) + end + end + + + mutable struct BuildGrn + networkStability::Matrix{Float64} + lambda::Union{Float64, Vector{Float64}} + targs::Vector{String} + regs::Vector{String} + rankings::Vector{Float64} + signedQuantile::Vector{Float64} + partialCorrelation::Vector{Float64} + inPrior::Vector{String} + networkMat::Matrix{Any} + networkMatSubset::Matrix{Any} + inPriorVec::Vector{Float64} + betas::Matrix{Float64} + function BuildGrn() + return new( + Matrix{Float64}(undef, 0, 0), # networkStability + 0.0, # lambda + [], # targs + [], # regs + [], # rankings + [], # signedQuantile + [], # partialCorrelation + [], # inPrior + Matrix{Float64}(undef, 0, 0), # networkMat + Matrix{Float64}(undef, 0, 0), # networkMatSubset + [], # inPriorVec + Matrix{Float64}(undef, 0, 0), # betas + [] # mergeTfLocVec + ) + end + end + + # Export only the functions that users need + export preparePredictorMat!, preparePenaltyMatrix!, constructSubsamples, 
bstarsWarmStart, bstartsEstimateInstability, + BuildGrn, GrnData, chooseLambda!, rankEdges!, combineGRNs, combineGRNS2 + + include("prepareGRN.jl") # functions that prepare predictor matrices etc. + include("buildGRN.jl") # functions that use GrnData / BuildGrn + include("aggregateNetworks.jl") # combineGRNs + include("refineTFA.jl") # combineGRNS2 + + +end \ No newline at end of file diff --git a/archive/Packages.txt b/archive/Packages.txt new file mode 100755 index 0000000..0a84b65 --- /dev/null +++ b/archive/Packages.txt @@ -0,0 +1,23 @@ +JLD2 +CSV +Arrow +DataFrames +OrderedCollections +Interpolations +InlineStrings +StatsPlots +PyPlot +Colors +Measures +GLMNet +Random +StatsBase +Distributions +CategoricalArrays +SparseArrays +TickTock +ProgressBars +NamedArrays +ArgParse +FileIO +GR \ No newline at end of file diff --git a/Testing/Packages.txt b/archive/Testing/Packages.txt similarity index 100% rename from Testing/Packages.txt rename to archive/Testing/Packages.txt diff --git a/Testing/R2_predict.bat b/archive/Testing/R2_predict.bat similarity index 100% rename from Testing/R2_predict.bat rename to archive/Testing/R2_predict.bat diff --git a/Testing/R2_predict.jl b/archive/Testing/R2_predict.jl similarity index 100% rename from Testing/R2_predict.jl rename to archive/Testing/R2_predict.jl diff --git a/Testing/combinePriors.jl b/archive/Testing/combinePriors.jl similarity index 100% rename from Testing/combinePriors.jl rename to archive/Testing/combinePriors.jl diff --git a/Testing/installRequirements.jl b/archive/Testing/installRequirements.jl similarity index 100% rename from Testing/installRequirements.jl rename to archive/Testing/installRequirements.jl diff --git a/Testing/misc/README.md b/archive/Testing/misc/README.md similarity index 100% rename from Testing/misc/README.md rename to archive/Testing/misc/README.md diff --git a/Testing/misc/calcPRinfTRNsMulti.jl b/archive/Testing/misc/calcPRinfTRNsMulti.jl similarity index 100% rename from 
Testing/misc/calcPRinfTRNsMulti.jl rename to archive/Testing/misc/calcPRinfTRNsMulti.jl diff --git a/Testing/misc/plotPR.jl b/archive/Testing/misc/plotPR.jl similarity index 100% rename from Testing/misc/plotPR.jl rename to archive/Testing/misc/plotPR.jl diff --git a/Testing/plotPR.jl b/archive/Testing/plotPR.jl similarity index 100% rename from Testing/plotPR.jl rename to archive/Testing/plotPR.jl diff --git a/Testing/plotPR_Old.jl b/archive/Testing/plotPR_Old.jl similarity index 100% rename from Testing/plotPR_Old.jl rename to archive/Testing/plotPR_Old.jl diff --git a/Testing/scratch/PR.jl b/archive/Testing/scratch/PR.jl similarity index 100% rename from Testing/scratch/PR.jl rename to archive/Testing/scratch/PR.jl diff --git a/Testing/scratch/TestingScript.jl b/archive/Testing/scratch/TestingScript.jl similarity index 100% rename from Testing/scratch/TestingScript.jl rename to archive/Testing/scratch/TestingScript.jl diff --git a/Testing/scratch/ToDo b/archive/Testing/scratch/ToDo similarity index 100% rename from Testing/scratch/ToDo rename to archive/Testing/scratch/ToDo diff --git a/Testing/scratch/calcPRinfTRNsMulti.jl b/archive/Testing/scratch/calcPRinfTRNsMulti.jl similarity index 100% rename from Testing/scratch/calcPRinfTRNsMulti.jl rename to archive/Testing/scratch/calcPRinfTRNsMulti.jl diff --git a/Testing/workflow.bat b/archive/Testing/workflow.bat similarity index 100% rename from Testing/workflow.bat rename to archive/Testing/workflow.bat diff --git a/Testing/workflow.jl b/archive/Testing/workflow.jl similarity index 100% rename from Testing/workflow.jl rename to archive/Testing/workflow.jl diff --git a/Testing/workflowCombined.jl b/archive/Testing/workflowCombined.jl similarity index 100% rename from Testing/workflowCombined.jl rename to archive/Testing/workflowCombined.jl diff --git a/Testing/workflowCombinedMulti.jl b/archive/Testing/workflowCombinedMulti.jl similarity index 100% rename from Testing/workflowCombinedMulti.jl rename to 
archive/Testing/workflowCombinedMulti.jl diff --git a/archive/installPackages.jl b/archive/installPackages.jl new file mode 100755 index 0000000..75e7849 --- /dev/null +++ b/archive/installPackages.jl @@ -0,0 +1,35 @@ +# Pkg module for managing packages +using Pkg +# Define the path to the file containing required package names +package_file = "/data/miraldiNB/Michael/Scripts/GRN/InferelatorJL/Packages.txt" + +# Read package names from the file +required_packages = [] +if isfile(package_file) + # Open the file and read package names + open(package_file) do file + for line in eachline(file) + line = strip(line) + if !isempty(line) + push!(required_packages, line) + end + end + end +else + println("Error: Package requirements file '$package_file' not found.") + exit(1) +end + +# deps = keys(Pkg.project().dependencies) + +# Check and install required packages +for pkg in required_packages + # if !(pkg in deps) + try + eval(Meta.parse("using $pkg")) # check if package already exist by loading + catch + println("Installing $pkg...") #install package if not installed + Pkg.add(pkg) + eval(Meta.parse("using $pkg")) # Load the package after installation + end +end diff --git a/julia_fxns/CombineTRN/.RData b/archive/julia_fxns/CombineTRN/.RData similarity index 100% rename from julia_fxns/CombineTRN/.RData rename to archive/julia_fxns/CombineTRN/.RData diff --git a/julia_fxns/CombineTRN/.Rhistory b/archive/julia_fxns/CombineTRN/.Rhistory similarity index 100% rename from julia_fxns/CombineTRN/.Rhistory rename to archive/julia_fxns/CombineTRN/.Rhistory diff --git a/julia_fxns/CombineTRN/combineTRN1.jl b/archive/julia_fxns/CombineTRN/combineTRN1.jl similarity index 100% rename from julia_fxns/CombineTRN/combineTRN1.jl rename to archive/julia_fxns/CombineTRN/combineTRN1.jl diff --git a/julia_fxns/CombineTRN/combineTRN2.jl b/archive/julia_fxns/CombineTRN/combineTRN2.jl similarity index 100% rename from julia_fxns/CombineTRN/combineTRN2.jl rename to 
archive/julia_fxns/CombineTRN/combineTRN2.jl diff --git a/julia_fxns/buildTRNs_mLassoStARS.jl b/archive/julia_fxns/buildTRNs_mLassoStARS.jl similarity index 100% rename from julia_fxns/buildTRNs_mLassoStARS.jl rename to archive/julia_fxns/buildTRNs_mLassoStARS.jl diff --git a/julia_fxns/calcAupr.jl b/archive/julia_fxns/calcAupr.jl similarity index 100% rename from julia_fxns/calcAupr.jl rename to archive/julia_fxns/calcAupr.jl diff --git a/julia_fxns/calcPRinfTRNs.jl b/archive/julia_fxns/calcPRinfTRNs.jl similarity index 100% rename from julia_fxns/calcPRinfTRNs.jl rename to archive/julia_fxns/calcPRinfTRNs.jl diff --git a/julia_fxns/calcR2predFromStabilities.jl b/archive/julia_fxns/calcR2predFromStabilities.jl similarity index 100% rename from julia_fxns/calcR2predFromStabilities.jl rename to archive/julia_fxns/calcR2predFromStabilities.jl diff --git a/julia_fxns/estimateInstabilitiesTRNbStARS.jl b/archive/julia_fxns/estimateInstabilitiesTRNbStARS.jl similarity index 100% rename from julia_fxns/estimateInstabilitiesTRNbStARS.jl rename to archive/julia_fxns/estimateInstabilitiesTRNbStARS.jl diff --git a/julia_fxns/getMLassoStARSinstabilitiesPerGeneAndNet.jl b/archive/julia_fxns/getMLassoStARSinstabilitiesPerGeneAndNet.jl similarity index 100% rename from julia_fxns/getMLassoStARSinstabilitiesPerGeneAndNet.jl rename to archive/julia_fxns/getMLassoStARSinstabilitiesPerGeneAndNet.jl diff --git a/julia_fxns/getMLassoStARSlambdaRangePerGene.jl b/archive/julia_fxns/getMLassoStARSlambdaRangePerGene.jl similarity index 100% rename from julia_fxns/getMLassoStARSlambdaRangePerGene.jl rename to archive/julia_fxns/getMLassoStARSlambdaRangePerGene.jl diff --git a/julia_fxns/groupSelection.jl b/archive/julia_fxns/groupSelection.jl similarity index 100% rename from julia_fxns/groupSelection.jl rename to archive/julia_fxns/groupSelection.jl diff --git a/julia_fxns/importGeneExpGeneLists.jl b/archive/julia_fxns/importGeneExpGeneLists.jl similarity index 100% rename from 
julia_fxns/importGeneExpGeneLists.jl rename to archive/julia_fxns/importGeneExpGeneLists.jl diff --git a/julia_fxns/integratePrior_estTFA.jl b/archive/julia_fxns/integratePrior_estTFA.jl similarity index 100% rename from julia_fxns/integratePrior_estTFA.jl rename to archive/julia_fxns/integratePrior_estTFA.jl diff --git a/julia_fxns/mergeDegeneratePriorTFs.jl b/archive/julia_fxns/mergeDegeneratePriorTFs.jl similarity index 100% rename from julia_fxns/mergeDegeneratePriorTFs.jl rename to archive/julia_fxns/mergeDegeneratePriorTFs.jl diff --git a/julia_fxns/partialCorrelation.jl b/archive/julia_fxns/partialCorrelation.jl similarity index 100% rename from julia_fxns/partialCorrelation.jl rename to archive/julia_fxns/partialCorrelation.jl diff --git a/julia_fxns/plotConfidenceDistribution.jl b/archive/julia_fxns/plotConfidenceDistribution.jl similarity index 100% rename from julia_fxns/plotConfidenceDistribution.jl rename to archive/julia_fxns/plotConfidenceDistribution.jl diff --git a/julia_fxns/plotMetricUtils.jl b/archive/julia_fxns/plotMetricUtils.jl similarity index 100% rename from julia_fxns/plotMetricUtils.jl rename to archive/julia_fxns/plotMetricUtils.jl diff --git a/julia_fxns/priorUtils.jl b/archive/julia_fxns/priorUtils.jl similarity index 100% rename from julia_fxns/priorUtils.jl rename to archive/julia_fxns/priorUtils.jl diff --git a/julia_fxns/scratch/mergeDegeneratePriorTFs.jl b/archive/julia_fxns/scratch/mergeDegeneratePriorTFs.jl similarity index 100% rename from julia_fxns/scratch/mergeDegeneratePriorTFs.jl rename to archive/julia_fxns/scratch/mergeDegeneratePriorTFs.jl diff --git a/julia_fxns/updatePenaltyMatrix.jl b/archive/julia_fxns/updatePenaltyMatrix.jl similarity index 100% rename from julia_fxns/updatePenaltyMatrix.jl rename to archive/julia_fxns/updatePenaltyMatrix.jl diff --git a/archive/packages01.txt b/archive/packages01.txt new file mode 100755 index 0000000..c3e7336 --- /dev/null +++ b/archive/packages01.txt @@ -0,0 +1,148 @@ +Adapt 
[79e6a3ab-5dfb-504d-930d-738a2a938a0e] +AliasTables [66dad0bd-aa9a-41b7-9441-69ab47430ed8] +ArgParse [c7e460c6-2fb9-53a9-8c5b-16f535851c63] +ArgTools [0dad84c5-d112-42e6-8d28-ef12dabb789f] +Arrow [69666777-d1a9-59fb-9406-91d4454c9d45] +ArrowTypes [31f734f8-188a-4ce0-8406-c8a06bd891cd] +Artifacts [56f22d72-fd6d-98f1-02f0-08ddc0907c33] +AxisAlgorithms [13072b0f-2c55-5437-9ae7-d433b7a33950] +Base [top-level] +Base64 [2a0f44e3-6c83-55bd-87e4-b1978d98bd5f] +BitIntegers [c3b6d118-76ef-56ca-8cc7-ebb389d030a1] +CRC32c [8bf52ea8-c179-5cab-976a-9e18b702a9bc] +CSV [336ed68f-0bac-5ca0-87d4-7b16caf5d00b] +CategoricalArrays [324d7699-5711-5eae-9e2f-1d82baa6b597] +ChainRulesCore [d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4] +ChangesOfVariables [9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0] +CodeTracking [da1fd8a2-8d9e-5ec2-8556-3022fb5608a2] +CodecLz4 [5ba52731-8f18-5e0d-9241-30f10d1ec561] +CodecZlib [944b1d66-785c-5afd-91f1-9de20f533193] +CodecZstd [6b39b394-51ab-5f42-8807-6242bab2b4c2] +ColorTypes [3da002f7-5984-5a60-b8a6-cbb66c0b333f] +Colors [5ae59095-9a9b-59fe-a467-6f913c188581] +Combinatorics [861a8166-3701-5b0c-9a16-15d98fcdc6aa] +Compat [34da2185-b29b-5c13-b0c7-acf172513d20] +CompilerSupportLibraries_jll [e66e0078-7015-5450-92f7-15fbd957f2ae] +ConcurrentUtilities [f0e56b4a-5159-44fe-b623-3e5288b988bb] +Conda [8f4d0f93-b110-5947-807f-2305c1781a2d] +Core [top-level] +Crayons [a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f] +DataAPI [9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a] +DataFrames [a93c6f00-e57d-5684-b7b6-d8193f3e46c0] +DataStructures [864edb3b-99cc-5e75-8d2d-829cb0a9cfe8] +DataValueInterfaces [e2d170a0-9d28-54be-80f0-106bbe20a464] +Dates [ade2ca70-3891-5945-98fb-dc099432e06a] +DelimitedFiles [8bb1440f-4735-579b-a4ab-409b98df4dab] +DensityInterface [b429d917-457f-4dbc-8f4c-0cc954292b1d] +Distributed [8ba89e20-285c-5b6f-9357-94700520ee1b] +Distributions [31c24e10-a181-5473-b8eb-7969acd0382f] +DocStringExtensions [ffbed154-4ef7-542d-bbb7-c09d3a79fcae] +Downloads 
[f43a241f-c20a-4ad4-852c-f6b1247861c6] +EnumX [4e289a0a-7415-4d19-859d-a7e5c4648b56] +ExprTools [e2ba6199-217a-4e67-a87a-7c52f15ade04] +FileIO [5789e2e9-d7fb-5bc7-8068-2c6fae9b9549] +FilePathsBase [48062228-2e41-5def-b9a4-89aafe57970f] +FileWatching [7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee] +FillArrays [1a297f60-69ca-5386-bcde-b61e274b549b] +FixedPointNumbers [53c48c17-4a7d-5ca2-90c5-79b7896eea93] +Future [9fa8497b-333b-5362-9e8d-4d0656e87820] +GLMNet [8d5ece8b-de18-5317-b113-243142960cc6] +HypergeometricFunctions [34004b35-14d8-5ef3-9330-4cdb6864b03a] +InlineStrings [842dd82b-1e85-43dc-bf29-5d0ee9dffc48] +InteractiveUtils [b77e0a4c-d291-57a0-90e8-8db25a27a240] +Interpolations [a98d9a8b-a2ab-59e6-89dd-64a1c18fca59] +InverseFunctions [3587e190-3f89-42d0-90ee-14403ec27112] +InvertedIndices [41ab1584-1d38-5bbf-9106-f11c6c58b48f] +IrrationalConstants [92d709cd-6900-40b7-9082-c6be49f344b6] +IteratorInterfaceExtensions [82899510-4779-5014-852e-03e436cf321d] +JLD2 [033835bb-8acc-5ee8-8aae-3f567f8a3819] +JLLWrappers [692b3bcd-3c85-4b1f-b108-f13ce0eb3210] +JSON [682c06a0-de6a-54ab-a142-c8b1cf79cde6] +JuliaInterpreter [aa1ae85d-cabe-5617-a682-6adf51b2e16a] +LaTeXStrings [b964fa9f-0449-5b57-a5c2-d3ea65f4040f] +LazyArtifacts [4af54fe1-eca0-43a8-85a7-787d91b784e3] +LibCURL [b27032c2-a3e7-50c8-80cd-2d36dbcbfd21] +LibCURL_jll [deac9b47-8bc7-5906-a0fe-35ac56dc84c0] +LibGit2 [76f85450-5226-5b5a-8eaa-529ad045b433] +Libdl [8f399da3-3557-5675-b5ff-fb832c97cbdb] +LinearAlgebra [37e2e46d-f89d-539d-b4ee-838fcccc9c8e] +LogExpFunctions [2ab3a3ac-af41-5b50-aa03-7779005ae688] +Logging [56ddb016-857b-54e1-b83d-db4d58db5568] +LoweredCodeUtils [6f1432cf-f94c-5a45-995e-cdbf5db27b0b] +Lz4_jll [5ced341a-0733-55b8-9ab6-a4889d929147] +MacroTools [1914dd2f-81c6-5fcd-8719-6d5c9610ff09] +Main [top-level] +Markdown [d6f4376e-aef5-505a-96c1-9c027394607a] +Measures [442fdcdd-2543-5da2-b0f3-8c86c306513e] +Missings [e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28] +Mmap [a63ad114-7e13-5084-954f-fe012c677804] +Mocking 
[78c3b35d-d492-501b-9361-3d52fe80e533] +MozillaCACerts_jll [14a3606d-f60d-562e-9121-12d972cd8159] +NamedArrays [86f7a689-2022-50b4-a561-43c23ac3c673] +NetworkOptions [ca575930-c2e3-43a9-ace4-1e988b2c1908] +OffsetArrays [6fe1bfb0-de20-5000-8ca7-80f57d26f881] +OpenLibm_jll [05823500-19ac-5b8b-9628-191a04bc5112] +OpenSSL_jll [458c3c95-2e84-50aa-8efc-19380b2a3a95] +OpenSpecFun_jll [efe28fd5-8261-553b-a9e1-b2916fc3738e] +OrderedCollections [bac558e1-5e72-5ebc-8fee-abe8a469f55d] +PDMats [90014a1f-27ba-587c-ab20-58faa44d9150] +Parsers [69de0a69-1ddd-5017-9359-2bf0b02dc9f0] +Pkg [44cfe95a-1eb2-52ea-b672-e2afdf69b78f] +PooledArrays [2dfb63ee-cc39-5dd5-95bd-886bf059d720] +PrecompileTools [aea7be01-6a6a-4083-8856-8a6e6704d82a] +Preferences [21216c6a-2e73-6563-6e65-726566657250] +PrettyTables [08abe8d2-0d0c-5749-adfa-8a2ac140af0d] +Printf [de0858da-6303-5e67-8744-51eddeeeb8d7] +Profile [9abbd945-dff8-562f-b5e8-e1ebf5ef1b79] +ProgressBars [49802e3a-d2f1-5c88-81d8-b72133a6f568] +PtrArrays [43287f4e-b6f4-7ad1-bb20-aadabca52c3d] +PyCall [438e738f-606a-5dbb-bf0a-cddfbfd45ab0] +PyPlot [d330b81b-6aea-500a-939a-2ce795aea3ee] +QuadGK [1fd47b50-473d-5c70-9696-f719f8f3bcdc] +REPL [3fa0cd96-eef1-5676-8a61-b3b8758bbffb] +Random [9a3f8284-a2c9-5f02-9a11-845980a1fd5c] +Ratios [c84ed2f1-dad5-54f0-aa8e-dbefe2724439] +RecipesBase [3cdcf5f2-1ef4-517c-9805-6587b60abb01] +Reexport [189a3867-3050-52da-a836-e630ba90ab69] +Requires [ae029012-a4dd-5104-9daa-d747884805df] +Revise [295af30f-e4ad-537b-8983-00126c2a3abe] +Rmath [79098fc4-a85e-5d69-aa6a-4863f24498fa] +Rmath_jll [f50d1b31-88e8-58de-be2c-1cc44531875f] +SHA [ea8e919c-243c-51af-8825-aaa63cd721ce] +Scratch [6c6a2e73-6563-6170-7368-637461726353] +SentinelArrays [91c51154-3ec4-41a3-a24f-3f23e20d615c] +Serialization [9e88b42a-f829-5b0c-bbe9-9e923198166b] +SharedArrays [1a1011a3-84de-559e-8e89-a11a2f7dc383] +Sockets [6462fe0b-24de-5631-8697-dd941f90decc] +SortingAlgorithms [a2af1166-a08f-5f64-846c-94a0d3cef48c] +SparseArrays 
[2f01184e-e22b-5df5-ae63-d93ebab69eaf] +SpecialFunctions [276daf66-3868-5448-9aa4-cd146d93841b] +StaticArrays [90137ffa-7385-5640-81b9-e52037218182] +StaticArraysCore [1e83bf80-4336-4d27-bf5d-d5a4f845583c] +Statistics [10745b16-79ce-11e8-11f9-7d13ad32a3b2] +StatsAPI [82ae8749-77ed-4fe6-ae5f-f523153014b0] +StatsBase [2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91] +StatsFuns [4c63d2b9-4356-54db-8cca-17b64c39e42c] +StringManipulation [892a3eda-7b42-436c-8928-eab12a02cf0e] +SuiteSparse [4607b0f0-06f3-5cda-b6b1-a6196a1729e9] +TOML [fa267f1f-6049-4f14-aa54-33bafae1ed76] +TZJData [dc5dba14-91b3-4cab-a142-028a31da12f7] +TableTraits [3783bdb8-4a98-5b6b-af9a-565f29a5fe9c] +Tables [bd369af6-aec1-5ad0-b16a-f7cc5008161c] +Tar [a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e] +Test [8dfed614-e22c-5e08-85e1-65c5234f0b40] +TextWrap [b718987f-49a8-5099-9789-dcd902bef87d] +TickTock [9ff05d80-102d-5586-aa04-3a8bd1a90d20] +TimeZones [f269a46b-ccf7-5d73-abea-4c690281aa53] +TranscodingStreams [3bb67fe8-82b1-5028-8e26-92a6c54297fa] +UUIDs [cf7118a7-6976-5b1a-9a39-7adc72f591a4] +Unicode [4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5] +VersionParsing [81def892-9a0e-5fdd-b105-ffc91e053289] +WeakRefStrings [ea10d353-3f73-51f8-a26c-33c1cb351aa5] +WoodburyMatrices [efce3f68-66dc-5838-9240-27a6d6f5f9b6] +WorkerUtilities [76eceee3-57b5-4d4a-8e66-0e911cebbf60] +Zlib_jll [83775a58-1f1d-513f-b197-d71354ab007a] +Zstd_jll [3161d3a3-bdf6-5164-811a-617609db77b4] +glmnet_jll [78c6b45d-5eaf-5d68-bcfb-a5a2cb06c27f] +nghttp2_jll [8e850ede-7688-5339-a07c-302acd2aaf8d] +p7zip_jll [3f19e933-33d8-53b3-aaab-bd5110c3b7a0] diff --git a/evaluation/R/ComparePrior.R b/evaluation/R/ComparePrior.R new file mode 100755 index 0000000..e0163aa --- /dev/null +++ b/evaluation/R/ComparePrior.R @@ -0,0 +1,33 @@ +library(ggplot2) + +prior1 <- read.table("/data/miraldiNB/Katko/Projects/Barski_CD4_Multiome/Outs/TRAC_loop/Prior/Prior_sum.tsv") +prior2 <- 
read.table("/data/miraldiNB/Katko/Projects/Barski_CD4_Multiome/Outs/Seurat/Prior/MEMT_050723_FIMOp5_b.tsv") + +index <- intersect(rownames(prior1), rownames(prior2)) +prior1 <- prior1[index,] +prior2 <- prior2[index,] + +prior_df <- as.matrix(rep(0, length(colnames(prior1)))) +prior_df <- cbind(prior_df, rep(0, length(colnames(prior2)))) +rownames(prior_df) <- colnames(prior1) +colnames(prior_df) <- c("TRAC","Body") + +for(i in 1:length(colnames(prior1))){ + prior1_targets <- length(which(prior1[,i] > 0)) + prior2_targets <- length(which(prior2[,i] > 0)) + prior_df[i,1] <- prior1_targets + prior_df[i,2] <- prior2_targets +} + +prior_df <- as.data.frame(prior_df) +pdf("Prior_Scatter.pdf", width = 8, height = 8) +ggplot(prior_df, aes(x = TRAC, y = Body)) + geom_point() + geom_abline(slope=1, intercept=0) +dev.off() + +prior_df_hist <- data.frame(Freq = c(prior_df[,1], prior_df[,2])) +prior_df_hist$TF <- c(rownames(prior_df), rownames(prior_df)) +prior_df_hist$Prior <- c(rep("Prior1", length(colnames(prior1))), rep("Prior2", length(colnames(prior2)))) + +pdf("Prior_hist.pdf", width = 8, height = 8) +ggplot(prior_df_hist, aes(x = Freq, color = Prior)) + geom_histogram(fill = "white", position = "identity", alpha = 0.7) +dev.off() \ No newline at end of file diff --git a/evaluation/R/DeltaTFA.R b/evaluation/R/DeltaTFA.R new file mode 100755 index 0000000..50ac71d --- /dev/null +++ b/evaluation/R/DeltaTFA.R @@ -0,0 +1,376 @@ +library(biomaRt) +library(stringr) +library(DESeq2) + + +counts <- read.table("GSE271788_dedup_counts.txt", header = T) +rownames(counts) <- counts[,1] +counts <- counts[,7:length(counts)] + +new_names <- str_extract(colnames(counts), "Donor_\\d+_[A-Za-z0-9]+") +new_names <- sapply(strsplit(new_names, "_"), function(x) paste(x[3], x[1], x[2], sep = "_")) +colnames(counts) <- new_names + +ensembl_ids_clean <- sub("\\..*", "", rownames(counts)) +gtf <- import("gencode.v48.chr_patch_hapl_scaff.annotation.gtf.gz") +# Keep only gene entries +genes <- 
gtf[gtf$type == "gene"] + +# Build mapping table +mapping <- data.frame( + ensembl_gene_id = gsub("\\..*", "", genes$gene_id), # remove version + gene_name = genes$gene_name +) %>% + distinct() +symbol_map <- mapping$gene_name[match(gene_ids, mapping$ensembl_gene_id)] +symbol_map[is.na(symbol_map)] <- gene_ids[is.na(symbol_map)] + + +gene_ids <- gsub("\\..*", "", rownames(counts)) # remove versions +symbol_map <- mapping$gene_name[match(gene_ids, mapping$ensembl_gene_id)] +rownames(counts) <- ifelse(is.na(symbol_map), gene_ids, symbol_map) +counts <- counts[!grepl("^ENSG", rownames(counts)), ] + +# Build metadata dataframe +colnames(counts) <- make.unique(colnames(counts)) + +clean_names <- sub("\\.\\d+$", "", colnames(counts)) + +# Split by underscore +split_info <- do.call(rbind, strsplit(clean_names, "_")) + +# Build metadata dataframe +metadata <- data.frame( + TF = split_info[, 1], + Donor = paste(split_info[, 2], split_info[, 3], sep = "_"), + row.names = colnames(counts), + stringsAsFactors = FALSE +) + +dds <- DESeqDataSetFromMatrix( + countData = counts, + colData = metadata, + design = ~ Donor + TF +) +dds <- DESeq(dds) + +#cor_matrix <- dcast(network, Gene ~ TF, value.var = "signedQuantile", fill = 0) + + +### Plot TFA changes +library(tidyverse) +bulk_cor <- read.table("/data/miraldiNB/Katko/Projects/Julia/Inferelator_Julia/outputs/subNetworks/FullPseudobulk/FullPseudobulk/lambda0p5_200totSS_20tfsPerGene_subsamplePCT63/Combined/TFA_cor.txt") +bulk_quant <- read.table("/data/miraldiNB/Katko/Projects/Julia/Inferelator_Julia/outputs/subNetworks/FullPseudobulk/FullPseudobulk/lambda0p5_200totSS_20tfsPerGene_subsamplePCT63/Combined/TFA_quant.txt") +bulk_sign <- read.table("/data/miraldiNB/Katko/Projects/Julia/Inferelator_Julia/outputs/subNetworks/FullPseudobulk/FullPseudobulk/lambda0p5_200totSS_20tfsPerGene_subsamplePCT63/Combined/TFA_sign.txt") +sc_cor <- 
read.table("/data/miraldiNB/Katko/Projects/Julia/Inferelator_Julia/outputs/subNetworks/scSubsampleFraction/lambda0p5_200totSS_20tfsPerGene_subsamplePCT10/Combined/TFA_cor.txt") +sc_quant <- read.table("/data/miraldiNB/Katko/Projects/Julia/Inferelator_Julia/outputs/subNetworks/scSubsampleFraction/lambda0p5_200totSS_20tfsPerGene_subsamplePCT10/Combined/TFA_quant.txt") +sc_sign <- read.table("/data/miraldiNB/Katko/Projects/Julia/Inferelator_Julia/outputs/subNetworks/scSubsampleFraction/lambda0p5_200totSS_20tfsPerGene_subsamplePCT10/Combined/TFA_binary.txt") +tfa_matrix <- list(bulk_cor, bulk_quant, bulk_sign, sc_cor, sc_quant, sc_sign) +names(tfa_matrix) <- c("bulk_cor","bulk_quant","bulk_sign","sc_cor","sc_quant","sc_sign") + +tf_list <- c( + "AIRE", "BACH2", "BCL11B", "BPTF", "CLOCK", "EPAS1", "ETS1", "FOXK1", "FOXP1", "FOXP3", + "GATA3", "GFI1", "HIVEP2", "IKZF1", "IRF1", "IRF2", "IRF4", "IRF7", "IRF9", "KLF2", + "KMT2A", "MBD2", "MYB", "NFE2L2", "NFAT5", "NFKB1", "NFKB2", "RELA", + "RELB", "REL", "RFX5", "RORC", "SETDB1", "SREBF1", "STAT1", "STAT2", "STAT3", "STAT5A", + "STAT5B", "TBX21", "TCF3", "TP53", "YBX1", "YBX3", "YY1", "ZBTB14", "ZFP3", "ZKSCAN1", + "ZNF329", "ZNF341", "ZNF791" +) +tf_list <- tf_list[which(tf_list %in% rownames(tfa_matrix[[1]]))] +library(dplyr) +library(tidyr) +library(ggplot2) +library(purrr) + + +analyse_one_matrix <- function(tfa_matrix, + tf_list, + pdf_name, + sig_cutoffs = c(`***` = 0.001, + `**` = 0.01, + `*` = 0.05)) { + + p_to_symbol <- function(p) { + stars <- names(sig_cutoffs)[which(p < sig_cutoffs)] + if (length(stars) == 0) "ns" else stars[1] + } + + res <- purrr::map_dfr(tf_list, function(tf) { + + #--- pull the row safely --------------------------------------------------- + if (!tf %in% rownames(tfa_matrix)) { + warning("TF ", tf, " not present in matrix; skipping.") + return(tibble::tibble(TF = tf, p.value = NA_real_, + signif = "NA", direction = "NA", plot = list(NULL))) + } + + tf_activity <- tfa_matrix[tf, ] # numeric 
vector + tf_df <- tibble::tibble(Sample = names(tf_activity), + Activity = as.numeric(tf_activity)) %>% + dplyr::mutate( + Knockout_TF = sub("_Donor_.*", "", Sample), + Donor = sub(".*_Donor_", "", Sample) %>% sub("\\..*", "", .) + ) + + #--- average per donor ----------------------------------------------------- + controls <- tf_df %>% dplyr::filter(Knockout_TF == "AAVS1") %>% + dplyr::group_by(Donor) %>% + dplyr::summarise(Control = mean(Activity), .groups = "drop") + + knockouts <- tf_df %>% dplyr::filter(Knockout_TF == tf) %>% + dplyr::group_by(Donor) %>% + dplyr::summarise(Knockout = mean(Activity), .groups = "drop") + + paired <- dplyr::inner_join(controls, knockouts, by = "Donor") + + if (nrow(paired) < 2) { + warning("TF ", tf, ": fewer than 2 paired donors – skipped") + return(tibble::tibble(TF = tf, p.value = NA_real_, + signif = "NA", direction = "NA", plot = list(NULL))) + } + + #--- stats ----------------------------------------------------------------- + t_res <- stats::t.test(paired$Knockout, paired$Control, paired = TRUE) + p_val <- t_res$p.value + direction <- ifelse(mean(paired$Knockout - paired$Control) > 0, "Up", "Down") + signif_sym <- p_to_symbol(p_val) + + #--- plot ------------------------------------------------------------------ + long <- paired %>% + tidyr::pivot_longer(Control:Knockout, + names_to = "Condition", + values_to = "Activity") + + g <- ggplot2::ggplot(long, ggplot2::aes(Condition, Activity, group = Donor)) + + ggplot2::geom_line(ggplot2::aes(color = Donor), linewidth = 1.1, alpha = 0.8) + + ggplot2::geom_point(ggplot2::aes(color = Donor), size = 3) + + ggplot2::theme_minimal(base_size = 14) + + ggplot2::labs( + title = paste(tf, "Activity Before and After Knockout"), + subtitle = paste0("p = ", signif(p_val, 3), + " (", signif_sym, ", ", direction, ")"), + x = "Condition", y = "Estimated TF Activity", color = "Donor") + + ggplot2::theme(plot.title = ggplot2::element_text(hjust = 0.5, face = "bold"), + plot.subtitle = 
ggplot2::element_text(hjust = 0.5)) + + tibble::tibble(TF = tf, + p.value = p_val, + signif = signif_sym, + direction = direction, + plot = list(g)) + }) + + #--- write one PDF with all TFs --------------------------------------------- + grDevices::pdf(pdf_name, width = 8, height = 10) + purrr::walk(res$plot[!purrr::map_lgl(res$plot, is.null)], print) + grDevices::dev.off() + + res # keep the plot column this time +} + +################################################################ +## 2. MANY-MATRIX WRAPPER (robust, no imap_dfr required) ## +################################################################ +analyse_many_matrices <- function(tfa_list, # named list or named file vector + tf_list, + out_dir = ".") { + + # read .rds files if character vector of paths was supplied + if (is.character(tfa_list) && !is.matrix(tfa_list[[1]])) { + tfa_list <- lapply(tfa_list, readRDS) + names(tfa_list) <- basename(names(tfa_list)) # keep vector names + } + + if (is.null(names(tfa_list)) || any(names(tfa_list) == "")) + stop("tfa_list must be a *named* list or vector so datasets can be labelled.") + + dir.create(out_dir, showWarnings = FALSE, recursive = TRUE) + + results <- vector("list", length(tfa_list)) + out_i <- 1L + + for (mat_name in names(tfa_list)) { + + mat <- tfa_list[[mat_name]] + if (!(is.matrix(mat) || is.data.frame(mat))) { + warning("Skipping '", mat_name, "': not a matrix/data.frame") + next + } + + pdf_file <- file.path(out_dir, + paste0("TF_activity_changes_", mat_name, ".pdf")) + + message("Processing ", mat_name, " …") + stats_one <- tryCatch( + analyse_one_matrix(mat, tf_list, pdf_name = pdf_file), + error = function(e) { + warning("Failed on ", mat_name, ": ", conditionMessage(e)) + NULL + } + ) + + if (is.null(stats_one)) next # skip failures + + stats_one <- tibble::as_tibble(stats_one) # force tibble + stats_one <- dplyr::mutate(stats_one, dataset = mat_name, .before = 1) + + results[[out_i]] <- stats_one + out_i <- out_i + 1L + } + + 
dplyr::bind_rows(results[seq_len(out_i - 1L)]) +} + + +all_stats <- analyse_many_matrices(tfa_matrix, + tf_list, + out_dir = "plots") # PDFs in ./plots + +## View or export the combined statistics: +print(all_stats) +library(ComplexHeatmap) +library(circlize) +library(grid) +library(dplyr) +library(tidyr) + +pct_expressing <- function(obj, + genes, + assay = "RNA", + slot = "data") { + + expr <- Seurat::GetAssayData(obj[[assay]], slot = slot) # genes × cells + vec <- vapply(genes, function(g) { + if (!g %in% rownames(expr)) return(NA_real_) + mean(expr[g, ] > 0) * 100 # % of cells + }, numeric(1)) + names(vec) <- genes + vec +} + +DefaultAssay(obj) <- "RNA" # if not already set +pct_vec <- pct_expressing(obj, tf_list) + +score_mat <- all_stats %>% + mutate(score = ifelse( + is.na(p.value), + NA_real_, + -log10(p.value) * ifelse(direction == "Up", 1, -1)) + ) %>% + select(TF, dataset, score) %>% + pivot_wider(names_from = dataset, values_from = score) %>% + as.data.frame() + +rownames(score_mat) <- score_mat$TF +score_mat <- as.matrix(score_mat[, -1, drop = FALSE]) + +star_mat <- all_stats %>% + mutate(stars = case_when( + is.na(p.value) ~ "", + p.value < 0.001 ~ "***", + p.value < 0.01 ~ "**", + p.value < 0.10 ~ "*", + TRUE ~ "" + )) %>% + select(TF, dataset, stars) %>% + pivot_wider(names_from = dataset, values_from = stars) %>% + as.data.frame() + +rownames(star_mat) <- star_mat$TF +star_mat <- as.matrix(star_mat[, -1, drop = FALSE]) + +deg_counts <- setNames(rep(0, nrow(score_mat)), rownames(score_mat)) +deg_counts[names(target_number)] <- target_number # overwrite where present + +# ---- LEFT annotation: % of cells expressing ---------------------------------- +row_anno_left <- rowAnnotation( + `% Cells\nExpressing` = anno_barplot( + pct_vec[rownames(score_mat)], # ensure correct order + gp = gpar(fill = "steelblue", col = NA), + bar_width = 0.85, + border = FALSE, + axis_param = list(at = c(0, 50, 100), gp = gpar(fontsize = 7)) + ), + annotation_name_side = 
"top", + annotation_name_gp = gpar(fontsize = 9, fontface = "bold"), + width = unit(2.4, "cm") +) + +# ---- RIGHT annotation: # of DE genes ----------------------------------------- +row_anno_right <- rowAnnotation( + `# DE genes` = anno_barplot( + deg_counts, + gp = gpar(fill = "grey40", col = NA), + bar_width = 0.85, + border = FALSE, + axis_param = list(at = c(0, max(deg_counts)), gp = gpar(fontsize = 7)) + ), + annotation_name_side = "top", + annotation_name_gp = gpar(fontsize = 9, fontface = "bold"), + width = unit(2.3, "cm") +) + +max_abs <- max(abs(score_mat), na.rm = TRUE) +col_fun <- circlize::colorRamp2(c(-max_abs, 0, max_abs), + c("navy", "white", "firebrick")) +text_col <- function(fill) { + rgb <- col2rgb(fill) / 255 + ifelse(0.299*rgb[1] + 0.587*rgb[2] + 0.114*rgb[3] < 0.5, "white", "black") +} + +row_labels <- rowAnnotation( + TF = anno_text( + rownames(score_mat), + gp = gpar(fontsize = 9), + just = "left", + location = 0.5 # centred vertically in each cell + ), + width = unit(2.2, "cm"), # reserve space for the longest name + show_annotation_name = FALSE +) + +############################################################################### +## 2. Core heat-map (row names turned OFF) ################################## +############################################################################### +ht_core <- Heatmap( + score_mat, + name = "-log10(p)", + col = col_fun, + na_col = "grey90", + border = TRUE, + row_km = 3, + + show_row_names = FALSE, # <── row names handled by row_labels + column_names_gp = gpar(fontsize = 9), + + heatmap_legend_param = list( + title_gp = gpar(fontsize = 10, fontface = "bold"), + labels_gp = gpar(fontsize = 9) + ), + cell_fun = function(j, i, x, y, width, height, fill) { + lab <- star_mat[i, j] + if (nzchar(lab)) { + grid.text( + lab, x, y, + gp = gpar(fontsize = 8, + fontface = "bold", + col = text_col(fill)) + ) + } + } +) + +############################################################################### +## 3. 
Assemble LEFT + CORE + LABELS + RIGHT ########################## +############################################################################### +ht_full <- row_anno_left + ht_core + row_labels + row_anno_right +# %-expressing matrix TF names #DE genes + +############################################################################### +## 4. Draw / save ########################################################### +############################################################################### +pdf("/data/miraldiNB/Katko/Projects/Julia/Inferelator_Julia/outputs/subNetworks/scSubsampleFraction/lambda0p5_200totSS_20tfsPerGene_subsamplePCT10/Combined/plots//TF_heatmap_with_two_bars_and_labels.pdf", width = 7.5, height = 9) +draw(ht_full, + heatmap_legend_side = "right", + annotation_legend_side = "right") +dev.off() \ No newline at end of file diff --git a/evaluation/R/GRN_Tfh10_PR_heatmap.R b/evaluation/R/GRN_Tfh10_PR_heatmap.R new file mode 100755 index 0000000..83f420e --- /dev/null +++ b/evaluation/R/GRN_Tfh10_PR_heatmap.R @@ -0,0 +1,444 @@ +# GRN_Tfh10_PR_heatmap_final2.R +# Heatmap of PR results +rm(list=ls()) +options(stringsAsFactors=FALSE) +set.seed(42) +suppressPackageStartupMessages({ +library(ggplot2) +library(scales) +}) +source('/data/miraldiNB/Michael/Scripts/GRN/utils_grn_analysis.R') +source('/data/miraldiNB/wayman/scripts/net_utils.R') +source('/data/miraldiNB/wayman/scripts/clustering_utils.R') + +####### INPUTS ####### + +# Output directory +dir_out <- '/data/miraldiNB/Michael/mCD4T_Wayman/Figures/Tfh10_S4B/2p5PCT' +dir.create(dir_out, showWarnings=FALSE, recursive=TRUE) + +# File save name +file_save <- 'Tfh10_S4' + +# Gene target file +file_targ <- "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/targRegLists/targetGenes_names.txt" + +# TF list +file_tf <- "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/targRegLists/potRegs_names.txt" + +# list of networks (in order) +file_grn <- c( + +# No Prior : 
lambdaBias=1 + tfaOpt='_TFmRNA' + ATAC_KOprior +"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda1p0_80totSS_20tfsPerGene_subsamplePCT63/TFmRNA/edges_subset.txt", # NoPrior (TFmRNA + ATAC) + +# ATAC Prior +"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63/TFmRNA/edges_subset.txt", # +mRNA (ATAC) +"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda0p25_80totSS_20tfsPerGene_subsamplePCT63/TFmRNA/edges_subset.txt", # ++mRNA (ATAC) + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda1p0_80totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.txt", # TFA +"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.txt", # +TFA (ATAC) +"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda0p25_80totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.txt", # ++TFA (ATAC) +"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63/Combined/combined.tsv", # +Combine (ATAC) + + +# ATAC Prior (sc-Inferelator) +# + +# ATAC+ChIP Prior +"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63/TFmRNA/edges_subset.txt", # +mRNA (ATAC + ChIP) +"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/Bulk/lambda0p25_80totSS_20tfsPerGene_subsamplePCT63/TFmRNA/edges_subset.txt", # ++mRNA (ATAC + ChIP) +"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/Bulk/lambda1p0_80totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.txt", # TFA +"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.txt", # +TFA (ATAC + ChIP) +"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/Bulk/lambda0p25_80totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.txt", # ++TFA (ATAC + ChIP) 
+"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63/Combined/combined.tsv", # +Combine (ATAC + ChIP) + +# ATAC + ChIP Prior (sc- Inferelator) +# + +# ATAC+KO Prior +"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63/TFmRNA/edges_subset.txt", # +mRNA (ATAC + KO) +"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/Bulk/lambda0p25_80totSS_20tfsPerGene_subsamplePCT63/TFmRNA/edges_subset.txt", # ++mRNA (ATAC + KO) +"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/Bulk/lambda1p0_80totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.txt", # TFA +"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.txt", # +TFA (ATAC + KO) +"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/Bulk/lambda0p25_80totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.txt", # ++TFA (ATAC + KO) +"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63/Combined/combined.tsv", # +Combine (ATAC + KO) + +# ATAC + KO Prior (sc- Inferelator) +# + +#Cell Oracle +# "/data/miraldiNB/Michael/mCD4T_Wayman/CellOracle/GRNs/Unprocessed/Combined/Old/max_pCutoff_0.1pct_combined_GRN_3cols.tsv", +"/data/miraldiNB/Michael/mCD4T_Wayman/CellOracle/GRNs/Unprocessed/CombinedGRNs/max/max_pCut0.1_meanEdgesPerGene20_combined_GRN.tsv", + +# Scenic Plus +"/data/miraldiNB/Michael/mCD4T_Wayman/scenicPlus/eRegulons_direct_filtered.tsv", # Direct +"/data/miraldiNB/Michael/mCD4T_Wayman/scenicPlus/eRegulons_extended_filtered.tsv", # Extended + +#SupirFactor +"/data/miraldiNB/Michael/mCD4T_Wayman/SupirFactor/Hierarchical_20250212_2043/GRN_Hierarchical_long.tsv", # Hierarchical +"/data/miraldiNB/Michael/mCD4T_Wayman/SupirFactor/Shallow_20250330_0212/GRN_Shallow_long.tsv" # Shallow +) + +# network names +name_grn <- c( + +'No Prior', + +'ATAC 
mRNA +', +'ATAC mRNA ++', +'ATAC TFA', +'ATAC TFA +', +'ATAC TFA ++', +'ATAC Combined +', + +# 'sc ATAC mRNA +', +# 'sc ATAC TFA +', +# 'sc ATAC Combined +', + +'ChIP mRNA +', +'ChIP mRNA ++', +'ChIP TFA', +'ChIP TFA +', +'ChIP TFA ++', +'ChIP Combined +', + +# 'sc ChIP mRNA +', +# 'sc ChIP TFA +', +# 'sc ChIP Combined +', + +'KO mRNA +', +'KO mRNA ++', +'KO TFA', +'KO TFA +', +'KO TFA ++', +'KO Combined +', + +# 'sc KO mRNA +', +# 'sc KO TFA +', +# 'sc KO Combined +', + +'Cell Oracle', +# 'Cell Oracle Mean', + +'SCENIC+ Dir', +'SCENIC+ Ext', + +'SupirFactor H', +'SupirFactor Sh' + +) + +# network type +type_grn <- c( + +'No Prior', + +'ATAC', +'ATAC', +'ATAC', +'ATAC', +'ATAC', +'ATAC', + +# 'ATAC', +# 'ATAC', +# 'ATAC', + +'ChIP', +'ChIP', +'ChIP', +'ChIP', +'ChIP', +'ChIP', + +# 'ChIP', +# 'ChIP', +# 'ChIP', + +'KO', +'KO', +'KO', +'KO', +'KO', +'KO', + +# 'KO', +# 'KO', +# 'KO', + +'Cell Oracle', +# 'Cell Oracle', + +'SCENIC+', +'SCENIC+', + +'SupirFactor', +'SupirFactor' +) + +# Gold standard names +name_gs <- c( +'ChIP', +'KO', +'KC' +) + +# List of gold standards +file_gs <- c( + # '/data/miraldiNB/wayman/projects/Tfh10/outs/202204/GRN/GS/TF_binding/priors/FDR5_Rank50/prior_ChIP_Thelper_Miraldi2019Th17_combine_FDR5_Rank50_sp.tsv', + # '/data/miraldiNB/wayman/projects/Tfh10/outs/202204/GRN/GS/RNA/priors/prior_RNA_Thelper_Miraldi2019Th17_combine_Log2FC0p5_FDR20_Rank50_sp.tsv', + # '/data/miraldiNB/wayman/projects/Tfh10/outs/202204/GRN/GS/KC/prior_KC_Thelper_Miraldi2019Th17_Rank100_sp.tsv' + "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/goldStandards/prior_ChIP_Thelper_Miraldi2019Th17_combine_FDR5_Rank50_sp.tsv", + "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/goldStandards/prior_TF_KO_RNA_Thelper_Miraldi2019Th17_combine_Log2FC0p5_FDR20_Rank50_sp.tsv", + "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/goldStandards/prior_KC_Thelper_Miraldi2019Th17_Rank100_sp.tsv" +) + +# gold 
standard save file name +save_gs <- '' + +# params +min_targ_gs <- 20 # gold standard min target cutoff +cutoff_rec <- 0.025 # recall cutoff +range_fc <- c(-2,2) # color scale range for log2fc heatmap +breaks_fc <- c(-2,-1,0,1,2) +heat_h <- 8 # heatmap height +# heat_w <- 15 # heatmap width +heat_w <- 21 # heatmap width +box_w <- 6 # boxplot width + +###################### + +# Output file label +label_pcut <- 100*cutoff_rec +file_save <- if (nzchar(save_gs)) paste0(file_save, '_', save_gs) else file_save + + +# Load tf list +tf <- readLines(file_tf) + +# Calc precision-recall +df_pr <- NULL +for (ix in 1:length(file_grn)){ + + print(paste0('Process GRN: ',ix,'/',length(file_grn))) + for (jx in 1:length(file_gs)){ + + # # Filter gold standard TFs + gs <- read.delim(file_gs[jx], header=TRUE, sep='\t') + gs <- subset(gs, TF %in% tf) + keep_tf <- names(which(table(gs$TF) >= min_targ_gs)) + gs <- subset(gs, TF %in% keep_tf) + + # All PR + curr_tf <- sort(unique(gs$TF)) + pr <- grn_pr(grn=file_grn[ix], gs=gs, gene_tar=file_targ, tf=curr_tf) + prec <- pr$P; prec <- c(prec, pr$Rand) + recall <- pr$R; recall <- c(recall, 1) + + # Precision at cutoff + slope_pr <- (prec[which(recall >= cutoff_rec)[1]] - prec[which(recall >= cutoff_rec)[1]-1])/ + (recall[which(recall >= cutoff_rec)[1]] - recall[which(recall >= cutoff_rec)[1]-1]) + pcut <- prec[which(recall >= cutoff_rec)[1]-1] + + slope_pr*(cutoff_rec-recall[which(recall >= cutoff_rec)[1]-1]) + curr_df <- data.frame(GRN=name_grn[ix], Type=type_grn[ix], GS=name_gs[jx], + TF='All TF', Pcut=pcut, Rand=pr$Rand) + df_pr <- rbind(df_pr, curr_df) + + for (kx in curr_tf){ + + # pr <- grn_pr(grn=file_grn[ix], gs=file_gs[jx], gene_tar=file_targ, tf=kx) + pr <- grn_pr(grn=file_grn[ix], gs=gs, gene_tar=file_targ, tf=kx) + if (!is.null(pr)){ + prec <- pr$P; prec <- c(prec, pr$Rand) + recall <- pr$R; recall <- c(recall, 1) + + # precision at cutoff + slope_pr <- (prec[which(recall >= cutoff_rec)[1]] - prec[which(recall >= 
cutoff_rec)[1]-1])/ + (recall[which(recall >= cutoff_rec)[1]] - recall[which(recall >= cutoff_rec)[1]-1]) + pcut <- prec[which(recall >= cutoff_rec)[1]-1] + + slope_pr*(cutoff_rec-recall[which(recall >= cutoff_rec)[1]-1]) + curr_df <- data.frame(GRN=name_grn[ix], Type=type_grn[ix], GS=name_gs[jx], + TF=kx, Pcut=pcut, Rand=pr$Rand) + df_pr <- rbind(df_pr, curr_df) + } + } + } +} +df_pr$PcutLog2FC <- log2(df_pr$Pcut/df_pr$Rand) +df_pr$GS_TF <- paste(df_pr$GS, df_pr$TF, sep='_') + +# order gold standards +order_gs <- NULL +for (ix in name_gs){ + order_gs <- c(order_gs, paste(ix,'All TF',sep='_')) + curr_pr <- subset(df_pr, GS==ix & TF != 'All TF') + mat <- net_sparse_2_full(curr_pr[,c('GS_TF','GRN','PcutLog2FC')]) + curr_order <- colnames(mat)[cluster_hierarchical(mat)] + order_gs <- c(order_gs, curr_order) +} +df_pr$GS_TF <- factor(df_pr$GS_TF, levels=order_gs) + +# order networks +df_pr$GRN <- factor(df_pr$GRN, levels=rev(name_grn)) + + +# Axis label +label_gs <- order_gs +for(ix in name_gs){ + label_gs <- gsub(paste0(ix,'_'),'',label_gs) +} +face_gs <- rep('plain',length(label_gs)) +face_gs[which(label_gs=='All TF')] <- 'bold' + +tf_ko <- unique(subset(df_pr, GS=='KO' & TF != 'All TF')$TF) +tf_chip <- unique(subset(df_pr, GS=='ChIP' & TF != 'All TF')$TF) +tf_int <- intersect(tf_ko, tf_chip) +for (ix in tf_int){ + face_gs[which(label_gs==ix)[1]] <- 'italic' +} + +# Lines separating gold standards +# vline_sep <- ceiling(cumsum(table(df_pr$GS)[name_gs]/length(name_grn)))+0.5 +tf_counts <- tapply(df_pr$TF, df_pr$GS, function(x) length(unique(x))) +vline_sep <- cumsum(tf_counts[name_gs]) + 0.5 +vline_sep <- vline_sep[1:(length(vline_sep)-1)] + +# lines separating network types +order_type <- rev(unique(type_grn)) +hline_sep <- cumsum(table(type_grn)[order_type])+0.5 +hline_sep <- hline_sep[1:(length(hline_sep)-1)] + +color_pal <- c(low='dodgerblue3', 'white', high='firebrick3') # c('#2c7bb6','#abd9e9','white','#d7191c','#b10026') +# Heatmap precision at cutoff 
+print('Heatmap precision at cutoff') +ggplot(df_pr, aes(x=GS_TF, y=GRN, fill=PcutLog2FC)) + + geom_tile() + + # geom_tile(data = subset(df_pr, !is.na(PcutLog2FC))) + # Removes NA tiles + # geom_tile(data = subset(df_pr, is.na(PcutLog2FC)), fill = "lightgrey") + # Adds NA tiles without borders + geom_vline(xintercept=vline_sep, lwd=0.8, colour='gray30') + + geom_hline(yintercept=hline_sep, lwd=0.8, colour='gray30') + + labs(x=' ', y=' ', fill='Log2 FC') + + theme(axis.text=element_text(size=12,face='plain')) + + theme(axis.title=element_text(size=14,face='plain')) + + theme(axis.text.x=element_text(angle=45, hjust=1)) + + theme(axis.text.x=element_text(face=face_gs)) + + # scale_fill_gradient2(low='dodgerblue3', high='firebrick3', limits=range_fc, na.value="lightgrey",oob=squish) + + scale_fill_gradientn( + # colors=c(low='dodgerblue3', high='firebrick3'), + limits=range_fc, + colors= color_pal, + breaks=breaks_fc, + na.value="lightgrey",oob=squish + )+ + scale_x_discrete(labels=label_gs) +file_out <- file.path(dir_out, paste0('heat_',file_save,'_PRcut',label_pcut,'.pdf')) +ggsave(file_out, height=heat_h, width=heat_w) + +ggplot(df_pr, aes(x=GS_TF, y=GRN, fill=PcutLog2FC)) + + geom_tile() + + # geom_tile(data = subset(df_pr, !is.na(PcutLog2FC))) + # Removes NA tiles + # geom_tile(data = subset(df_pr, is.na(PcutLog2FC)), fill = "lightgrey") + # Adds NA tiles without borders + geom_vline(xintercept=vline_sep, lwd=0.8, colour='gray30') + + geom_hline(yintercept=hline_sep, lwd=0.8, colour='gray30') + + labs(x=' ', y=' ', fill=' ') + + theme(axis.text=element_text(size=12,face='plain')) + + theme(axis.title=element_text(size=14,face='plain')) + + theme(axis.text.x=element_text(angle=45, hjust=1)) + + theme(axis.text.x=element_text(face=face_gs)) + + # scale_fill_gradient2(low='dodgerblue3', high='firebrick3', limits=range_fc, na.value="lightgrey",oob=squish) + + scale_fill_gradientn( + # colors=c(low='dodgerblue3', high='firebrick3'), + limits=range_fc, + colors= 
color_pal, + breaks=breaks_fc, + na.value="lightgrey",oob=squish + )+ + scale_x_discrete(labels=label_gs) + + theme(legend.position='bottom', legend.text=element_text(size=16), legend.key.width=unit(1.5,'cm')) +file_out <- file.path(dir_out, paste0('heat_',file_save,'_PRcut',label_pcut,'_2.pdf')) +ggsave(file_out, height=heat_h, width=heat_w) + +# Boxplot log2fc +print('boxplot log2fc') +for (ix in name_gs){ + df_pr_sub <- subset(df_pr, GS==ix & TF != 'All TF') + ggplot(df_pr_sub, aes(x=GRN, y=PcutLog2FC)) + geom_boxplot() + coord_flip() + + theme_minimal() + + labs(x=' ', y='Log2 Fold-Change') + + theme(axis.text=element_text(size=12,face='plain')) + + theme(axis.title=element_text(size=14,face='plain')) + + theme(axis.line=element_line(size=0.6, colour='grey30', linetype='solid')) + file_out <- file.path(dir_out, paste0('box_',file_save,'_PRcut',label_pcut,'_',ix,'.pdf')) + ggsave(file_out, height=heat_h, width=box_w) +} + + + +# set regions where GS is tested on itself to NA +df_pr_grey <- df_pr +df_pr_grey$PcutLog2FC <- ifelse((df_pr$Type == "ChIP" & df_pr$GS != "KO") | + (df_pr$Type == "KO" & df_pr$GS != "ChIP"), + NA, df_pr$PcutLog2FC) + +# Heatmap precision at cutoff +print('Heatmap precision at cutoff - greyed out GS tested regions') +ggplot(df_pr_grey, aes(x=GS_TF, y=GRN, fill=PcutLog2FC)) + + geom_tile() + + # geom_tile(data = subset(df_pr, !is.na(PcutLog2FC))) + # Removes NA tiles + # geom_tile(data = subset(df_pr, is.na(PcutLog2FC)), fill = "lightgrey") + # Adds NA tiles without borders + geom_vline(xintercept=vline_sep, lwd=0.8, colour='gray30') + + geom_hline(yintercept=hline_sep, lwd=0.8, colour='gray30') + + labs(x=' ', y=' ', fill='Log2 FC') + + theme(axis.text=element_text(size=12,face='plain')) + + theme(axis.title=element_text(size=14,face='plain')) + + theme(axis.text.x=element_text(angle=45, hjust=1)) + + theme(axis.text.x=element_text(face=face_gs)) + + # scale_fill_gradient2(low='dodgerblue3', high='firebrick3', limits=range_fc, 
na.value="lightgrey",oob=squish) + + scale_fill_gradientn( + limits=range_fc, + colors= color_pal, + breaks=breaks_fc, + na.value="lightgrey",oob=squish + )+ + scale_x_discrete(labels=label_gs) +file_out <- file.path(dir_out, paste0('heat_',file_save,'_PRcut',label_pcut,'_greyed_out.pdf')) +ggsave(file_out, height=heat_h, width=heat_w) + +ggplot(df_pr_grey, aes(x=GS_TF, y=GRN, fill=PcutLog2FC)) + + geom_tile() + + # geom_tile(data = subset(df_pr, !is.na(PcutLog2FC))) + # Removes NA tiles + # geom_tile(data = subset(df_pr, is.na(PcutLog2FC)), fill = "lightgrey") + # Adds NA tiles without borders + geom_vline(xintercept=vline_sep, lwd=0.8, colour='gray30') + + geom_hline(yintercept=hline_sep, lwd=0.8, colour='gray30') + + labs(x=' ', y=' ', fill=' ') + + theme(axis.text=element_text(size=12,face='plain')) + + theme(axis.title=element_text(size=14,face='plain')) + + theme(axis.text.x=element_text(angle=45, hjust=1)) + + theme(axis.text.x=element_text(face=face_gs)) + + # scale_fill_gradient2(low='dodgerblue3', high='firebrick3', limits=range_fc, na.value="lightgrey",oob=squish) + + scale_fill_gradientn( + # colors=c(low='dodgerblue3', high='firebrick3'), + limits=range_fc, + colors= color_pal, + breaks=breaks_fc, + na.value="lightgrey",oob=squish + )+ + scale_x_discrete(labels=label_gs) + + theme(legend.position='bottom', legend.text=element_text(size=16), legend.key.width=unit(1.5,'cm')) +file_out <- file.path(dir_out, paste0('heat_',file_save,'_PRcut',label_pcut,'_greyed_out_2.pdf')) +ggsave(file_out, height=heat_h, width=heat_w) + +# Boxplot log2fc +print('boxplot log2fc') +for (ix in name_gs){ + df_pr_sub <- subset(df_pr_grey, GS==ix & TF != 'All TF') + ggplot(df_pr_sub, aes(x=GRN, y=PcutLog2FC)) + geom_boxplot() + coord_flip() + + theme_minimal() + + labs(x=' ', y='Log2 Fold-Change') + + theme(axis.text=element_text(size=12,face='plain')) + + theme(axis.title=element_text(size=14,face='plain')) + + theme(axis.line=element_line(size=0.6, colour='grey30', 
linetype='solid')) + file_out <- file.path(dir_out, paste0('box_',file_save,'_PRcut',label_pcut,'_',ix,'_greyed_out.pdf')) + ggsave(file_out, height=heat_h, width=box_w) +} +print('DONE') diff --git a/evaluation/R/TFA_Viz.R b/evaluation/R/TFA_Viz.R new file mode 100755 index 0000000..8a89c1b --- /dev/null +++ b/evaluation/R/TFA_Viz.R @@ -0,0 +1,245 @@ +rm(list = ls()) +options(stringsAsFactors=FALSE) +suppressPackageStartupMessages({ + library(ComplexHeatmap) + library(dplyr) + library(tidyverse) + library(circlize) + library(RColorBrewer) + library(gridtext) + library(readxl) + library(reshape2) + library(Seurat) + library(ggplot2) +}) + +source("/data/miraldiNB/Michael/Scripts/standardize_normalize_pseudobulk.R") +source("/data/miraldiNB/Michael/Scripts/GSEA/GSEA_utils.R") + +dirOut <- "/data/miraldiNB/Michael/mCD4T_Wayman/Figures/Fig4/Test" +dir.create(dirOut, showWarnings = F, recursive = T) + + +# network <- "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.tsv" +tfa_file <- "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/Combined/TFA.txt" +file_k_clust_across <- NULL #file.path(dir_out, 'k_clust_across.rds') +file_k_clust_within <- NULL #file.path(dir_out, 'k_clust_within.rds') + +order_celltype <- c('Tfh10','Tfh_Int','Tfh', + 'Tfr','cTreg','eTreg','rTreg','Treg_Rorc', + 'Th17','Th1','CTL_Prdm1','CTL_Bcl6', + #'TM_Act','TM_ISG', + 'TEM','TCM') +order_age <- c('Young','Old') +order_rep <- c('R1','R2','R3','R4') + +tfa <- read.table(tfa_file) +meta_data <- NULL +for (ix in order_celltype){ + for (jx in order_age){ + curr_sample <- paste(order_rep, jx, ix, sep='_') + curr_meta <- data.frame(row.names=curr_sample, CellType=ix, Age=jx, Rep=order_rep) + meta_data <- rbind(meta_data, curr_meta) + } +} + +# diff <- setdiff(rownames(meta_data), colnames(tfa)) +meta_data <- 
meta_data[intersect(colnames(tfa), rownames(meta_data)), ]
+
+# rearrange tfa matrix
+tfa <- tfa[ ,rownames(meta_data)]
+# Z-scoring or Standardization
+z_tfa <- standardizeAndNormalizeCounts(tfa, meta_data = meta_data, celltype_var = "CellType", epsilon = NULL)
+tfa_z_across = z_tfa$z_across
+tfa_z_within = z_tfa$z_within
+
+# Create Annotation Color
+getPalette = colorRampPalette(brewer.pal(9, "Set1"))
+celltype_colors = getPalette(length(unique(meta_data[,1])))
+celltype_colors <- setNames(celltype_colors, c(unique(meta_data[,1])))
+treatment_colors <- c('Young'='grey66', 'Old'='grey33')
+
+meta_data[,2] <- factor(meta_data[,2], levels=order_age)
+#Create annotation column
+colAnn_top <- HeatmapAnnotation(
+        `Cell Type` = meta_data[,1],
+        `Age` = meta_data[,2],
+        col = list('Cell Type' = celltype_colors,
+                   'Age' = treatment_colors),
+        show_annotation_name = FALSE
+        )
+
+across_center <- 6
+# within_center <- 4
+gc()
+if (is.null(file_k_clust_across)){
+  k_clust_across <- kmeans(tfa_z_across, centers=across_center,nstart=20, iter.max = 50)
+}else{
+  k_clust_across <- readRDS(file_k_clust_across)
+}
+
+# gc()
+# if (is.null(file_k_clust_within)){
+#   k_clust_within <- kmeans(tfa_z_within, centers=within_center, nstart=20, iter.max = 50)
+# }else{
+#   k_clust_within <- readRDS(file_k_clust_within)
+# }
+
+# # Cluster
+k_clust_across <- kmeans(tfa_z_across, centers=6,nstart=20, iter.max = 50)
+split_by <- factor(k_clust_across$cluster, levels = c(1:6))
+
+heat_col <- colorRamp2(c(-2, 0, 2), c('dodgerblue3','white','red')) ## Colors for heatmap
+pdf(file.path(dirOut, "htmap_zscore_across.pdf"), width = 4, height = 6, compress = T)
+ht1 <- Heatmap(tfa_z_across,
+               name='Z-score',
+               show_row_names = F,
+               row_split=split_by,
+               row_gap = unit(0, "mm"),
+               column_gap = unit(0, "mm"),
+               border = TRUE,
+               row_title = NULL,
+               row_dend_reorder = F,
+               use_raster = F,
+               col = heat_col,
+               # clustering settings
+               cluster_rows = TRUE,        # allow hierarchical clustering
+               cluster_row_slices = TRUE,  # reorder within
each k-means cluster + cluster_columns = FALSE, + cluster_column_slices = F, + # column_order = colnames(tfa_z_across), + + # row/column dendrograms + show_row_dend = FALSE, # show dendrogram within slices + show_column_dend = FALSE, + + show_column_names = FALSE, + column_names_side = 'top', + # column_names_rot = 45, + column_title = NULL, + column_split = meta_data[,1], + top_annotation = colAnn_top + #bottom_annotation = colAnn_bottom + ) +draw(ht1) +dev.off() + + +# cluster_annot <- data.frame(across=k_clust_across$cluster,within=k_clust_within$cluster) +# cluster_annot$cluster <- paste0(cluster_annot$across,cluster_annot$within) +# cluster_annot <- cluster_annot[order(cluster_annot$cluster), ] +# within_clust_order <- c(1,2,3,4) +# across_clust_order <- c(2,3,5,4,6,1) +# cluster_annot$within <- factor(cluster_annot$within, levels=within_clust_order, ordered=T) +# cluster_annot$across <- factor(cluster_annot$across, levels=across_clust_order, ordered=T) + + +# annot_cols_across <- c('1'='#cc0001','2'='#fb940b','3'='#ffff01','4'='#01cc00','5'='#2085ec','6'='#fe98bf','7'='#762ca7','8'='#ad7a5b', +# '9'='grey50', '10'='turquoise4') + +# annot_cols_within <- c('1'='#badf55','2'='#35b1c9','3'='#b06dad','4'="#14A76C", '5' = 'grey90', '6' = '#e96060') + +# #reorder clusters +# split <- factor(cluster_annot$cluster, levels=c( +# 21,22,23,24, +# 31,32,33,34, +# 51,52,53,54, +# 41,42,43,44, +# 61,62,63,64, +# 11,12,13,14 +# )) + +# cluster_annot$cluster <- factor(cluster_annot$cluster, levels=levels(split)) +# cluster_annot <- cluster_annot[order(cluster_annot$cluster),] +# names(split) <- rownames(cluster_annot) + +# rowAnn_across_left1 <- HeatmapAnnotation(`Across` = cluster_annot[,'across'], #df = cluster_annot[,'across'], +# col = list('Across'= annot_cols_across), +# which = 'row', +# simple_anno_size = unit(3, 'mm'), +# #annotation_width = unit(c(1, 4), 'cm'), +# #gap = unit(0, 'mm'), +# show_annotation_name = F) + + +# pdf(file.path(dirOut, 
"htmap_zscore_across1.pdf"), width = 4, height = 6, compress = T) +# ht1 <- Heatmap(tfa_z_across[rownames(cluster_annot),], +# name='Z-score', +# show_row_names = F, +# row_split=cluster_annot$across, +# show_column_dend =F, +# row_gap = unit(0, "mm"), +# column_gap = unit(0, "mm"), +# border = TRUE, +# row_title = NULL, +# row_dend_reorder = F, +# use_raster = F, +# col = heat_col, +# column_names_gp = gpar(fontsize = 8, fontface = 2), +# show_row_dend = FALSE, +# cluster_rows = TRUE, +# cluster_columns = F, +# cluster_row_slices = F, +# cluster_column_slices = F, +# column_order = colnames(tfa_z_across), +# show_column_names = FALSE, +# column_names_side = 'top', +# # column_labels = gt_render(column_label), +# column_names_rot = 45, +# column_title = NULL, +# column_split = meta_data[,1], +# left_annotation = rowAnn_across_left1, +# # right_annotation = rowAnn_across_right1, +# row_order = rownames(cluster_annot), +# top_annotation = colAnn_top +# #bottom_annotation = colAnn_bottom +# ) + + + +# PART C: TFA Visualization between data representation +tfaFiles <- list( + "PB" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/Combined/TFA.txt", + "SC" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/SC/ATAC/geneLambda0p5_220totSS_20tfsPerGene_subsamplePCT5/Combined/TFA.txt", + "MC2" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/metaCells/MC2/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63_logNorm/Combined/TFA.txt", + "SEA" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/metaCells/SEACells/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/Combined/TFA.txt" +) + +tfaMatList <- list() +for (typeName in names(tfaFiles)) { + filePath <- tfaFiles[[typeName]] + tfaMatList[[typeName]] <- read.table(filePath, header = TRUE, sep = "\t", stringsAsFactors = FALSE) +} + +numTFA <- length(tfaMatList) +tfaNames <- names(tfaMatList) + + + +# 
-----------------------------
+# Feature Plot
+# -----------------------------
+objPath <- "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/annotation_scrna_final/obj_Tfh10_RNA_annotated.rds"
+tfa_file1 <- "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/SC/ATAC/geneLambda0p5_220totSS_20tfsPerGene_subsamplePCT5/Combined/TFA.txt"
+obj <- readRDS(objPath)
+tfa_mat1 <- read.table(tfa_file1, header = T)
+lineage_TFs <- c("Bcl6","Maf","Batf","Tox","Ascl2",
+                 "Foxp3","Ikzf2","Rorc",
+                 "Rorc","Stat3",
+                 "Tbx21","Stat4","Eomes","Prdm1",
+                 "Tcf7","Klf2","Lef1")
+
+# Filter lineage TFs present in your TFA matrix
+lineage_TFs_present <- intersect(lineage_TFs, rownames(tfa_mat1))
+
+# Add each TFA as metadata
+for(tf in lineage_TFs_present){
+  obj[[tf]] <- tfa_mat1[tf, colnames(obj)]
+}
+
+# Example: FeaturePlot for all lineage TFs
+FeaturePlot(obj,
+            features = lineage_TFs_present,
+            cols = c("lightgrey","red"),
+            reduction = "umap")   # or "tsne" if you use tSNE
+
diff --git a/evaluation/R/evaluateNetUtils.R b/evaluation/R/evaluateNetUtils.R
new file mode 100755
index 0000000..bea8551
--- /dev/null
+++ b/evaluation/R/evaluateNetUtils.R
@@ -0,0 +1,644 @@
+suppressPackageStartupMessages({
+  library(dplyr)
+  library(tidyr)
+  library(ggplot2)
+  library(ComplexHeatmap)
+  library(pheatmap)
+  # library(VennDiagram)
+  # library(ggVennDiagram)
+  library(grid)
+  library(reshape2)
+  library(RColorBrewer)
+
+})
+
+# ==========================
+# PART A: EDGE SET CONSTRUCTION
+# ==========================
+
+# -------- 1. Build Edge Sets -----------
+getEdgeSets <- function(netDataList, tfCol="TF", targetCol="Gene", rankCol="signedQuantile",
+                        mode = c("topNpercent", "topN"), N = NULL) {
+  #' getEdgeSets
+  #'
+  #' Extract edge sets from a list of networks, optionally selecting top-ranked edges.
+  #'
+  #' @param netDataList A named list of data.frames. Each data.frame represents a network with at least
+  #'   the columns for TF, target gene, and ranking.
+ #' @param tfCol Character. Name of the column containing transcription factors (default "TF").
+ #' @param targetCol Character. Name of the column containing target genes (default "Gene").
+ #' @param rankCol Character. Name of the column containing ranking scores (default "signedQuantile").
+ #' @param mode Character. Method to select edges when N is specified. Choices are:
+ #'   "topNpercent" (default) - select top N percent of edges based on rankCol,
+ #'   "topN" - select top N edges.
+ #'   Ignored if N is NULL.
+ #' @param N Numeric. Number of edges (for "topN") or percent (for "topNpercent") to select.
+ #'   If NULL (default), all edges are returned and mode is ignored.
+ #'
+ #' @return A named list of character vectors. Each element contains edges in the form "TF~Gene".
+ #'
+  mode <- match.arg(mode)
+
+  edgeSets <- lapply(seq_along(netDataList), function(i) {
+    # validate BEFORE using rankCol — the filter below would error first otherwise
+    if(!rankCol %in% colnames(netDataList[[i]])) stop(paste("Column", rankCol, "not found in network", names(netDataList)[i]))
+
+    df <- netDataList[[i]] %>%
+      filter(.data[[rankCol]] != 0) %>%
+      arrange(desc(.data[[rankCol]]))
+
+    df <- df %>%
+      mutate(edge = paste(.data[[tfCol]], .data[[targetCol]], sep="~"))
+    # If N is NULL, return all edges
+    if (is.null(N)) {
+      return(df$edge)
+    }
+
+    if(mode == "topN") {
+      df <- df %>% slice_head(n = N)
+    } else { # topNpercent — quantile on abs() to match the abs() filter below
+      cutoff <- quantile(abs(df[[rankCol]]), 1 - N/100, na.rm = TRUE)
+      df <- df %>% filter(abs(.data[[rankCol]]) >= cutoff)
+    }
+
+    df$edge
+  })
+
+  names(edgeSets) <- names(netDataList)
+  return(edgeSets)
+}
+
+# ======================================================
+# PART B: Pairwise Overlap / Set Metrics & Correlation
+# ======================================================
+
+# - 1.
Global Network Jaccard ----------- +computeJaccardMatrix <- function(edgeSets) { + networkNames <- names(edgeSets) + numNetworks <- length(edgeSets) + + pairInt <- matrix(NA, nrow=numNetworks, ncol=numNetworks, dimnames=list(networkNames, networkNames)) + pairJac <- matrix(NA, nrow=numNetworks, ncol=numNetworks, dimnames=list(networkNames, networkNames)) + + if (numNetworks > 1) { + for (i in 1:(numNetworks-1)) { + for (j in (i+1):numNetworks) { + shared <- intersect(edgeSets[[i]], edgeSets[[j]]) + un <- union(edgeSets[[i]], edgeSets[[j]]) + pairInt[i,j] <- length(shared) + pairJac[i,j] <- pairJac[j,i] <- if (length(un)) length(shared)/length(un) else NA + } + } + } + + diag(pairInt) <- vapply(edgeSets, length, integer(1)) + diag(pairJac) <- 1 + + return(list(pairJac = pairJac, pairInt = pairInt)) +} + + +# ============================================================================== +# - 2. Compute Spearman Correlation between pairs of ranks or weights +# ============================================================================== +#' +#' @param netDataList Named list of data.frames. Each data.frame must contain TF, target gene, and score columns. +#' @param edgeSets Optional named list of character vectors (from getEdgeSets). If provided, only these edges are used. +#' @param tfCol Character. Name of TF column (default "TF"). +#' @param targetCol Character. Name of target gene column (default "Gene"). +#' @param rankCol Character. Name of score column (default "signedQuantile"). +#' +#' @return A tibble with columns: Net1, Net2, Spearman correlation. 
+#' +#' @examples +#' # Spearman on all edges +#' computeSpearman(netDataList) +#' +#' # Spearman on top edges only +#' edges_top10pct <- getEdgeSets(netDataList, N=10) +#' computeSpearman(netDataList, edgeSets = edges_top10pct) +computeSpearman <- function(netDataList, + edgeSets = NULL, + ref = NULL, + tfCol="TF", + targetCol="Gene", + rankCol="signedQuantile") { + + prepForCor <- function(df, edges = NULL) { + df <- df %>% + filter(.data[[rankCol]] != 0) %>% + mutate(edge = paste(.data[[tfCol]], .data[[targetCol]], sep="~")) %>% + dplyr::select(edge, score = .data[[rankCol]]) + + if(!is.null(edges)) { + df <- df %>% filter(edge %in% edges) + } + + return(df) + } + + # Prepare each network + nets <- lapply(seq_along(netDataList), function(i) { + edges_to_use <- if(!is.null(edgeSets)) edgeSets[[names(netDataList)[i]]] else NULL + prepForCor(netDataList[[i]], edges = edges_to_use) + }) + + netNames <- names(netDataList) + names(nets) <- netNames + + # --- Compute pairwise Spearman correlations + # Build comparison pairs + if(!is.null(ref)){ + if(!ref %in% netNames){ + stop("Reference not found in edgeSets") + } + combs <- lapply(setdiff(netNames, ref), function(x) + c(ref, x)) + } else { + combs <- combn(netNames, 2, simplify = FALSE) + } + + map_dfr(combs, function(p) { + a <- nets[[p[1]]] + b <- nets[[p[2]]] + + joined <- inner_join(a, b, by = "edge", suffix = c(".1", ".2")) + rho <- if(nrow(joined) > 1) { + cor(joined$score.1, joined$score.2, method = "spearman", use = "complete.obs") + } else {NA} + + tibble( + Net1 = p[1], + Net2 = p[2], + Spearman = rho + ) + }) +} + + +# ==== 3.Compute Edge overlaps ==== +computeOverlapStats <- function(edgeSets, ref = NULL){ + nets <- names(edgeSets) + # Build comparison pairs + if(!is.null(ref)){ + if(!ref %in% nets){ + stop("Reference not found in edgeSets") + } + combs <- lapply(setdiff(nets, ref), function(x) + c(ref, x)) + } else { + combs <- combn(nets, 2, simplify = FALSE) + } + # combs <- combn(names(edgeSets), 2, 
simplify=FALSE) + + map_dfr(combs, function(p){ + A <- edgeSets[[p[1]]] + B <- edgeSets[[p[2]]] + + nA <- length(A) + nB <- length(B) + + shared <- length(intersect(A,B)) + un <- length(union(A,B)) + + # Dice Coefficient (or Sorensen-Dice Index) + dice <- ifelse( + nA + nB > 0, + 2 * shared / (nA + nB), + NA + ) + + # What fraction of the smaller set is shared with the larger set? + # Szymkiewicz–Simpson coefficient + fracOverlap = + ifelse(min(nA,nB)>0, + shared / min(nA,nB), + NA) + tibble( + Net1 = p[1], + Net2 = p[2], + pairID = paste0(p[1], "~", p[2]), + size1 = nA, + size2 = nB, + Intersection = shared, + frac1in2 = ifelse(nA > 0, shared / nA, NA), + frac2in1 = ifelse(nB > 0, shared / nB, NA), + Jaccard = ifelse(un> 0, shared/un, NA), + Dice = dice, + FractionOverlap = fracOverlap + + ) + }) +} + +# ==== 3.Compute TF overlaps ==== +computeTFoverlap <- function(netDataList, ref = NULL, tfCol = "TF") { + # Build TF sets + tfSets <- lapply(netDataList, function(df) unique(df[[tfCol]])) + + nets <- names(tfSets) + + # Build comparison pairs + if(!is.null(ref)){ + if(!ref %in% nets){ + stop("Reference not found in netDataList") + } + combs <- lapply(setdiff(nets, ref), function(x) c(ref, x)) + } else { + combs <- combn(nets, 2, simplify = FALSE) + } + + # Compute pairwise overlaps + map_dfr(combs, function(p){ + A <- tfSets[[p[1]]] + B <- tfSets[[p[2]]] + nA <- length(A) + nB <- length(B) + ov <- intersect(A,B) + shared <- length(ov) + un <- length(union(A,B)) + + # dice <- ifelse(nA + nB > 0, 2 * shared / (nA + nB), NA) + # fracOverlap <- ifelse(min(nA,nB) > 0, shared / min(nA,nB), NA) + + tibble( + Net1 = p[1], + Net2 = p[2], + pairID = paste0(p[1], "~", p[2]), + size1 = nA, + size2 = nB, + Overlaps = paste(ov, collapse = ","), + nIntersection = shared, + frac1in2 = ifelse(nA > 0, shared / nA, NA), + frac2in1 = ifelse(nB > 0, shared / nB, NA) + # Jaccard = ifelse(un > 0, shared/un, NA), + # Dice = dice, + # FractionOverlap = fracOverlap + ) + }) +} + +# 
==================================================== +# PART C: TF-Level / Aggregated Metrics +# ==================================================== + +# ===== 1. TF-centric Jaccard ==== +#' For each TF, how consistent are its predicted targets across networks? +#' Biological networks are TF-driven, so checking target consistency per TF +#' Jaccard penalizes sets with different sizes. It measures proportion of shared elements out of all unique elements +#' +computeTFJaccard <- function(netDataList, tfList=NULL){ + networkNames <- names(netDataList) + numNetworks <- length(netDataList) + + allTFs <- unique(unlist(lapply(netDataList, function(df) df$TF))) + tfs <- if(!is.null(tfList)) intersect(allTFs, tfList) else allTFs + + pairNames <- combn(networkNames,2,function(x) paste(x, collapse="~")) + tfMatrix <- matrix(NA, nrow=length(tfs), ncol=length(pairNames), dimnames=list(tfs, pairNames)) + + for(tf in tfs){ + tfTargets <- lapply(netDataList, function(df) df$Gene[df$TF == tf]) + for(k in seq_along(pairNames)){ + pair <- combn(seq_len(numNetworks),2)[,k] + shared <- intersect(tfTargets[[pair[1]]], tfTargets[[pair[2]]]) + un <- union(tfTargets[[pair[1]]], tfTargets[[pair[2]]]) + tfMatrix[tf,k] <- if(length(un)) length(shared) /length(un) else NA + } + } + return(tfMatrix) +} + +# plotTFHeatmap <- function(tfMatrix, fileOut, fontsize_number=7, fontsize=9, fig_width = 6, fig_height = 6){ +# heatCols <- colorRampPalette(RColorBrewer::brewer.pal(9,"YlOrRd"))(100) +# pdf(fileOut, width = fig_width, height = fig_height) +# pheatmap(tfMatrix, +# display_numbers=TRUE, +# number_color="black", +# number_format="%.2f", +# fontsize_number=fontsize_number, +# fontsize=fontsize, +# cluster_rows=FALSE, +# cluster_cols=FALSE, +# show_rownames = TRUE, +# show_colnames = TRUE, +# legend = FALSE , # removes grid lines +# color=heatCols, +# main="" +# ) +# dev.off() +# } + +# ==== 2. 
Aggregated TF Jaccard per Network Pair ===== +computeAggregatedPairJaccard <- function(tfMatrix, fun=median){ + # tfMatrix: rows = TFs, cols = network pairs + aggVec <- apply(tfMatrix, 2, fun, na.rm=TRUE) + + # Convert to symmetric matrix + pairNames <- colnames(tfMatrix) + nets <- unique(unlist(strsplit(pairNames, "~"))) + numNetworks <- length(nets) + aggMat <- matrix(NA, nrow=numNetworks, ncol=numNetworks, dimnames=list(nets,nets)) + + for(k in seq_along(pairNames)){ + pair <- strsplit(pairNames[k], "~")[[1]] + aggMat[pair[1], pair[2]] <- aggVec[k] + aggMat[pair[2], pair[1]] <- aggVec[k] + } + diag(aggMat) <- 1 + return(aggMat) +} + + +#' ==== 3./ Compute and plot hub TF set overlap across networks ==== +#' +#' This function identifies the top N hub transcription factors (TFs) in each +#' network, computes pairwise similarity between hub sets across networks using +#' either the Jaccard index or the overlap coefficient, and plots a heatmap of +#' the similarity matrix. +#' +#' @param netDataList List of data frames, one per network, each containing edges. +#' @param tfCol Column name (string) indicating which column contains TF identifiers. +#' @param networkNames Character vector of network names, matching the order of netDataList. +#' @param dirOut Output directory for saving the heatmap PDF. +#' @param topN Integer. Number of top hubs (by edge count) to include per network. Default = 50. +#' @param metric Similarity metric. One of "jaccard" or "overlap". Default = "jaccard". +#' @param heatCols Color palette for the heatmap. Default = 100-color YlOrRd. +#' @param fontsize Base font size for heatmap text. Default = 9. +#' @param fontsize_number Font size for numbers displayed in cells. Default = 7. +#' #' @param fig_width Width of the PDF figure in inches. Default = 6. +#' @param fig_height Height of the PDF figure in inches. Default = 6. + +#' +#' @return Invisibly returns the similarity matrix used to generate the heatmap. 
+#' +#' @examples +#' hubSim <- computeHubOverlapHeatmap( +#' netDataList = netDataList, +#' tfCol = "TF", +#' networkNames = c("Net1", "Net2", "Net3"), +#' dirOut = "results/", +#' topN = 50, +#' metric = "jaccard", +#' fig_width = 8, +#' fig_height = 8 +#' ) +#' +#' @export +computeHubOverlapHeatmap <- function(netDataList, tfCol, networkNames, + dirOut, topN = 50, + metric = c("jaccard", "overlap"), + heatCols = colorRampPalette(RColorBrewer::brewer.pal(9, "YlOrRd"))(100), + fontsize = 9, fontsize_number = 7, fig_width = 6, fig_height = 6) { + # --- Argument check + metric <- match.arg(metric) + numNetworks <- length(netDataList) + + # --- 1. Count edges per TF and rank hubs + hubRankList <- lapply(netDataList, function(df) { + tfCounts <- table(df[[tfCol]]) + sort(tfCounts, decreasing = TRUE) + }) + + # --- 2. Take top N hubs + topHubs <- lapply(hubRankList, function(x) head(names(x), topN)) + + # --- 3. Initialize similarity matrix + hubSim <- matrix(NA, numNetworks, numNetworks, + dimnames = list(networkNames, networkNames)) + + # --- 4. Compute similarity + for (i in 1:(numNetworks-1)) { + for (j in (i+1):numNetworks) { + inter <- intersect(topHubs[[i]], topHubs[[j]]) + + if(metric == "jaccard") { + union <- union(topHubs[[i]], topHubs[[j]]) + hubSim[i,j] <- hubSim[j,i] <- length(inter) / length(union) + } else if(metric == "overlap") { + minSize <- min(length(topHubs[[i]]), length(topHubs[[j]])) + hubSim[i,j] <- hubSim[j,i] <- length(inter) / minSize + } + } + } + diag(hubSim) <- 1 + + # --- 5. 
Plot heatmap
+  pdf(file.path(dirOut, paste0("HubTF_Overlap_", metric, "_top", topN, ".pdf")), width = fig_width, height = fig_height)
+  pheatmap(hubSim,
+           display_numbers = TRUE,
+           number_color = "black",
+           number_format = "%.2f",
+           fontsize_number = fontsize_number,
+           fontsize = fontsize,
+           cluster_rows = FALSE,
+           cluster_cols = FALSE,
+           show_rownames = TRUE,
+           show_colnames = TRUE,
+           legend = FALSE,
+           color = heatCols,
+           main = "")
+  dev.off()
+
+  invisible(hubSim)
+}
+
+# ==================================================
+# - Plotting Utils
+# ==================================================
+
+
+# ==== Edge Sharing ====
+plotEdgeSharing <- function(edgeSets, fileSuffix, outDir=".") {
+  library(ggVennDiagram)
+  library(ComplexUpset)
+  library(ComplexHeatmap)
+
+  numNetworks <- length(edgeSets)
+  dir.create(outDir, showWarnings = FALSE, recursive = TRUE)
+
+  if(numNetworks <= 3){
+    # --- Venn diagram (print() needed so the ggplot is drawn into the pdf device)
+    pdf(file.path(outDir, paste0("Venn2_", fileSuffix, ".pdf")), width=6, height=6)
+    print(ggVennDiagram(edgeSets, label_alpha = 0.6) +
+      scale_fill_gradient(low="grey90", high="red") +
+      theme(legend.position="none"))
+    dev.off()
+
+  } else {
+    # --- UpSet plots for >3 networks
+    # edgeSets elements are character vectors of "TF~Gene" edges (see getEdgeSets)
+    allEdges <- unique(unlist(edgeSets))
+    incMat <- sapply(edgeSets, function(edges) as.integer(allEdges %in% edges))
+    colnames(incMat) <- names(edgeSets)
+    rownames(incMat) <- allEdges
+
+    # Intersect mode
+    mObjIntersect <- make_comb_mat(incMat, mode="intersect")
+    pdf(file.path(outDir, paste0("UpSet_intersect_", fileSuffix, ".pdf")), width=8.5, height=4)
+    draw(UpSet(mObjIntersect,
+               set_order = names(edgeSets),
+               top_annotation = upset_top_annotation(mObjIntersect, annotation_name_rot=90, axis=FALSE,
+                                                     add_numbers=TRUE, numbers_rot=90,
+                                                     gp=gpar(col=comb_degree(mObjIntersect), fontsize=6), height=unit(4,"cm")),
+               right_annotation = upset_right_annotation(mObjIntersect, gp=gpar(fill="black",
fontsize=6), + width=unit(4,"cm"), show_annotation_name=FALSE, add_numbers=TRUE))) + dev.off() + + # Distinct mode + mObjDistinct <- make_comb_mat(incMat, mode="distinct") + pdf(file.path(outDir, paste0("UpSet_distinct_", fileSuffix, ".pdf")), width=8.5, height=4) + draw(UpSet(mObjDistinct, + set_order = names(edgeSets), + top_annotation = upset_top_annotation(mObjDistinct, annotation_name_rot=90, axis=FALSE, + add_numbers=TRUE, numbers_rot=90, + gp=gpar(col=comb_degree(mObjDistinct), fontsize=6), height=unit(4,"cm")), + right_annotation = upset_right_annotation(mObjDistinct, gp=gpar(fill="black", fontsize=6), + width=unit(4,"cm"), show_annotation_name=FALSE, add_numbers=TRUE))) + dev.off() + } + + message("Edge sharing plots generated.") +} + + +plotSimilarityHeatmap <- function(simMat, fileOut="sim_heatmap.pdf", + fontsize_number=7, fontsize=9, fig_width=6, fig_height=6) { + simMat[lower.tri(simMat)] <- t(simMat)[lower.tri(simMat)] + heatCols <- colorRampPalette(RColorBrewer::brewer.pal(9,"YlOrRd"))(100) + pdf(fileOut, height = fig_height, width = fig_width) + pheatmap::pheatmap(simMat, + display_numbers=TRUE, + number_color="black", + number_format="%.2f", + fontsize_number=fontsize_number, + fontsize=fontsize, + cluster_rows=FALSE, + cluster_cols=FALSE, + show_rownames = TRUE, + show_colnames = TRUE, + legend = FALSE , + color=heatCols, + main="" + ) + dev.off() +} + +plotAggregatedJaccard <- function(aggMat, fileOut, fontsize_number=7, fontsize=9, fig_width = 6, fig_height = 6){ + heatCols <- colorRampPalette(RColorBrewer::brewer.pal(9,"YlOrRd"))(100) + pdf(fileOut, width = fig_width, height = fig_height) + pheatmap(aggMat, + display_numbers=TRUE, + number_color="black", + number_format="%.2f", + fontsize_number=fontsize_number, + fontsize=fontsize, + cluster_rows=FALSE, + cluster_cols=FALSE, + show_rownames = TRUE, + show_colnames = TRUE, + legend = FALSE , # removes grid lines + color=heatCols, + main="Aggregated TF Jaccard per Network Pair") + dev.off() +} 
+ +plotMetricTrend <- function(df, xCol, yCol, groupCol = "pairID", colorMap = NULL, + xLabel = NULL, yLabel = NULL, yLimits = NULL, xBreaks = NULL, + outFile = NULL, width = 6, height = 3, dpi = 600){ + p <- ggplot(df, + aes_string(x = xCol, y = yCol, color = groupCol, group = groupCol) + ) + + geom_line(linewidth = 0.5) + + geom_point(size = 1) + + theme_bw(base_family = "Helvetica") + + theme( + axis.text.x = element_text(size = 7, color = "black"), + axis.text.y = element_text(size = 7, color = "black"), + axis.title = element_text(size = 9, color = "black"), + legend.title = element_blank(), + legend.text = element_text(size = 9, color = "black"), + legend.position = "right", + legend.box.just = "left", + panel.grid.major = element_line(color = "grey80", linewidth = 0.3), + panel.grid.minor = element_line(color = "grey90", linewidth = 0.1) + ) + + # Optional scales + if(!is.null(colorMap)){ + p <- p + scale_color_manual(values = colorMap) + } + + if(!is.null(xBreaks)){ + p <- p + scale_x_continuous(breaks = xBreaks) + } + + if(!is.null(yLimits)){ + p <- p + scale_y_continuous(limits = yLimits) + } else{ + dataMin <- min(df[[yCol]], na.rm = TRUE) + dataMax <- max(df[[yCol]], na.rm = TRUE) + yMin <- if (dataMin >= 0) 0 else dataMin + yMax <- max(dataMax, 1) + yLimits <- c(yMin, yMax) + p <- p + scale_y_continuous(limits = yLimits) + } + # Labels + p <- p + labs( x = xLabel, y = yLabel ) + # Save if requested + if(!is.null(outFile)) ggsave(outFile, plot = p, width = width, height = height, dpi = dpi) + + return(p) +} + +# ========================================== +# High-Level Pipeline / API (Top Level) +# ========================================== +computeTopNMetrics <- function(netDataList, + ref = NULL, + topNs = NULL, + tfCol="TF", + targetCol="Target", + rankCol="Weight", + mode = c("topNpercent", "topN")) { + #' Compute Top-N overlap and correlation metrics for multiple networks + #' + #' @param netDataList Named list of network data.frames + #' @param 
topNs Numeric vector of percentages (Top N%) or number to evaluate + #' @param tfCol Character, column name for transcription factor + #' @param targetCol Character, column name for target gene + #' @param rankCol Character, column name for ranking score + #' + #' @return data.frame with columns: pairID, Jaccard, FractionOverlap, Spearman, topNs + mode = match.arg(mode) + if (is.null(topNs)) topNs <- "all" + topNResults <- lapply(topNs, function(k) { + cat("Top-N:", k, "\n") + + if (k == "all"){ + edges <- getEdgeSets(netDataList, tfCol = tfCol, + targetCol = targetCol, rankCol = rankCol) + } else{ + # Extract Top-N% edges + edges <- getEdgeSets(netDataList, tfCol = tfCol, + targetCol = targetCol, rankCol = rankCol, + mode = mode, N = k) + } + # Overlap statistics + stats <- computeOverlapStats(edges, ref) + stats$Npct <- if(k == "all") NA else k + + # Spearman correlation + corRes <- computeSpearman(netDataList, edgeSets = edges, ref = ref, tfCol = tfCol, targetCol = targetCol, rankCol = rankCol) + # Merge Spearman with stats + stats <- stats %>% + left_join(corRes %>% + mutate(pairID = paste(Net1, Net2, sep="~")) %>% + dplyr::select(pairID, Spearman), + by = "pairID") + + stats + }) + + bind_rows(topNResults) +} + diff --git a/evaluation/R/evaluateNetworks.R b/evaluation/R/evaluateNetworks.R new file mode 100755 index 0000000..ee2671f --- /dev/null +++ b/evaluation/R/evaluateNetworks.R @@ -0,0 +1,311 @@ +rm(list = ls()) +suppressPackageStartupMessages({ + library(dplyr) + library(tidyr) + library(ggplot2) + library(ComplexHeatmap) + library(pheatmap) + # library(VennDiagram) + library(ggVennDiagram) + library(grid) + library(reshape2) +}) + +source("/data/miraldiNB/Michael/Scripts/VennDiagram.R") + +# dirOut <- "/data/miraldiNB/Michael/hCD4T_Katko/Inferelator/noMergedTF/Bulk/5kbTSS/newPipe/spikeIgnored/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63" +dirOut <- "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/networkEval/" +dir.create(dirOut, 
recursive = T, showWarnings = F) + +# Combined/combined_max.tsv +# TFmRNA/edges_subset.txt +# netFiles <- list( +# "TFA" = "/data/miraldiNB/Michael/hCD4T_Katko/Inferelator/noMergedTF/SC/SCENIC/geneLambda0p5_220totSS_20tfsPerGene_subsamplePCT5_eRegulon/TFA/edges_subset.tsv", +# "TFmRNA" = "/data/miraldiNB/Michael/hCD4T_Katko/Inferelator/noMergedTF/SC/SCENIC/geneLambda0p5_220totSS_20tfsPerGene_subsamplePCT5_eRegulon/TFmRNA/edges_subset.tsv", +# "Combined" = "/data/miraldiNB/Michael/hCD4T_Katko/Inferelator/noMergedTF/SC/SCENIC/geneLambda0p5_220totSS_20tfsPerGene_subsamplePCT5_eRegulon/Combined/combined_max.tsv" +# ) + +netFiles <- list( + "TFA" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.tsv", + "TFmRNA" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/TFmRNA/edges_subset.tsv", + "Combined" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/Combined/combined_max.tsv" + ) + +# priorFile = "/data/miraldiNB/Michael/hCD4T_Katko/dataBank/Priors/MotifScan5kbTSS_b_sp.tsv" +priorFile = "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10_sp.tsv" +# k <- read.table("/data/miraldiNB/Michael/hCD4T_Katko/dataBank/Priors/SCENICp/SCENICp_RE2Glinks_FIMO_5Kb_derived_b.tsv") +# k$Target <- rownames(k) +# b <- melt(k, id.vars = "Target", variable.names = "TF", value.name = "Weigths") +# colnames(b)[2:3] <- c("TF", "Weights") +# b <- b[, c("TF", "Target", "Weights")] +# bb <- b[b$Weights != 0, ] +# write.table(bb, "/data/miraldiNB/Michael/hCD4T_Katko/dataBank/Priors/SCENICp/SCENICp_RE2Glinks_FIMO_5Kb_derived_sp.tsv", row.names = F, sep ="\t", quote = F) + +tfCol <- "TF" # specify TF column name here +targetCol <- "Gene" # specify Target column name here +priorTfCol <- 
"TF" # TF column name in prior file +priorTargetCol <- "Target" # Target column name in prior file +priorWgtCol <-"Weight" +nSelect <- 10 +stabCol <- "Stability" +compareNets <- TRUE +# ======================================================================== +# ------- Read Input files +# ======================================================================== + +# ----- Read Prior +priorData <- read.table(priorFile, header = TRUE, sep = "\t", stringsAsFactors = FALSE) +# priorData <- priorData[priorData[[priorWgtCol]] != 0, ] +priorPairs <- unique(paste(priorData[[priorTfCol]], priorData[[priorTargetCol]], sep = "_")) + +# --- Preload all network data into a list --- +netDataList <- list() +for (typeName in names(netFiles)) { + filePath <- netFiles[[typeName]] + netDataList[[typeName]] <- read.table(filePath, header = TRUE, sep = "\t", stringsAsFactors = FALSE) +} + +numNetworks <- length(netDataList) +networkNames <- names(netDataList) +# ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── +# ---- PART ONE: Make a table of number of network ppts - uniqueTFs, uniqueTargets, Total Interactions, and % supported by prior +# ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── + +# process each network +summaryList <- list() +for (typeName in networkNames) { + data <- netDataList[[typeName]] + + uniqueTf <- length(unique(data[[tfCol]])) + uniqueTarget <- length(unique(data[[targetCol]])) + + pairStrings <- unique(paste(data[[tfCol]], data[[targetCol]], sep = "_")) + totalInteractions <- length(pairStrings) + + supportedCount <- length(intersect(pairStrings, priorPairs)) + percentSupportedByPrior <- 100 * supportedCount / totalInteractions + + summaryList[[typeName]] <- data.frame(type = typeName, uniqueTf = uniqueTf, + uniqueTarget = uniqueTarget, totalInteractions = totalInteractions, + percentSupportedByPrior = 
percentSupportedByPrior) +} + +summaryTable <- do.call(rbind, summaryList) +print(summaryTable) + +write.table(as.data.frame(summaryTable), file.path(dirOut, "summaryStatistics.tsv"), col.names = NA, row.names = TRUE, quote = FALSE, sep = "\t") + + +# ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── +# ---- PART TWO: Compare Networks (can handle more than 2) +# ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── +if (compareNets){ + + #--- Build edge sets + edgeSets <- lapply(netDataList, function(df) paste(df[[tfCol]], df[[targetCol]], sep="~")) + names(edgeSets) <- names(netDataList) + networkNames <- names(netDataList) + numNetworks <- length(netDataList) + + #--- Edges shared between all networks + allShared <- Reduce(intersect, edgeSets) + + #1. --- Pairwise Jaccard, Intersections, and Heatmap + + if (numNetworks > 1) { + pairInt <- matrix(NA, numNetworks, numNetworks, dimnames=list(networkNames,networkNames)) + pairJac <- matrix(NA, numNetworks, numNetworks, dimnames=list(networkNames,networkNames)) + + for (i in 1:(numNetworks-1)) { + for (j in (i+1):numNetworks) { + shared <- intersect(edgeSets[[i]], edgeSets[[j]]) + un <- union(edgeSets[[i]], edgeSets[[j]]) + pairInt[i,j] <- length(shared) + pairJac[i,j] <- pairJac[j,i] <- if (length(un)) length(shared)/length(un) else NA + } + } + diag(pairInt) <- vapply(edgeSets, length, integer(1)) + diag(pairJac) <- 1 + + # heat-map of Jaccard ---------------------------------------- + heatCols <- colorRampPalette(RColorBrewer::brewer.pal(9,"YlOrRd"))(100) + pdf(file.path(dirOut, "Jaccard.pdf")) + pheatmap(pairJac, display_numbers = TRUE, main = "Pairwise Jaccard Index", color = heatCols) + dev.off() + + # -------- unique-edge counts ----------------------------------------- ## + uniqCountsDf <- data.frame( + network = names(edgeSets), + numUniqueEdges = 
sapply(seq_along(edgeSets), function(k){
+      curr <- edgeSets[[k]]
+      others <- unique(unlist(edgeSets[-k]))
+      sum(!curr %in% others)
+    }),
+    row.names = NULL,
+    stringsAsFactors = FALSE
+  )
+
+  # 3. ------- Save Pairwise and uniqueDF counts
+  pairDf <- melt(pairInt, varnames = c("network1", "network2"), value.name = "numIntersection", na.rm = TRUE)
+  write.table(pairDf, file.path(dirOut, "num_Pairwise_Intersection.tsv"), quote = F, col.names = NA, row.names=TRUE, sep = "\t")
+
+  }
+
+  # 2.------ Venn or UpSet Plot
+  # NOTE(review): the original `if (numNetworks >= 2) ... else if (numNetworks > 3)`
+  # made the UpSet branch unreachable — any count > 3 also satisfies >= 2, so the
+  # first branch always won. Restricting the Venn branch to 2-3 sets (the range a
+  # Venn diagram can sensibly show) makes the UpSet branch reachable for > 3 sets.
+  if (numNetworks >= 2 && numNetworks <= 3) {
+    vennInput <- edgeSets
+    ggVennDiagram(vennInput, label_alpha = 0.6) + scale_fill_gradient(low="grey90", high = "red")
+    ggsave(file.path(dirOut, "Venn2.pdf"))
+    # venn.plot <- venn.diagram(
+    #   vennInput,
+    #   category.names = names(vennInput),
+    #   filename = NULL,
+    #   output = TRUE,
+    #   main = "Shared TF~Gene Pairs"
+    # )
+    # pdf(file.path(dirOut, "Venn2.pdf"))
+    # grid::grid.draw(venn.plot)
+    # dev.off()
+
+  } else if (numNetworks > 3) {
+    # allEdges <- unique(unlist(edgeSets))
+    # incMat <- sapply(edgeSets, function(edgeSet) as.integer(allEdges %in% edgeSet))
+    # colnames(incMat) <- names(edgeSets)
+    # rownames(incMat) <- allEdges
+    upSetData <- edgeSets
+
+    # Build a ComplexHeatmap UpSet plot of the edge sets for the given combination mode
+    plotUpSet <- function(setsData, Mode){
+      mObj <- make_comb_mat(setsData, mode = Mode)
+      ht <- UpSet(mObj, set_order = names(setsData),
+                  top_annotation = upset_top_annotation(mObj, annotation_name_rot = 90, axis = FALSE,
+                                                        add_numbers = TRUE, numbers_rot = 90,
+                                                        gp = gpar(col = comb_degree(mObj), fontsize = 6), height = unit(4, "cm"),
+                                                        axis_param = list(side = "left")),
+                  right_annotation = upset_right_annotation(mObj, axis_param = list(side = "bottom"), #labels = FALSE,labels_rot = 0
+                                                            gp = gpar(fill = "black", fontsize = 6), width = unit(4, "cm"), show_annotation_name = FALSE,
+                                                            add_numbers = TRUE, axis = FALSE
+                  )
+      )
+      return(ht)
+    }
+
+    pdf(file.path(dirOut, "UpSet_distinct.pdf"), width = 8.5, height = 4)
+    htDistinct <- plotUpSet(upSetData, Mode = "distinct")
+ 
draw(htDistinct) + dev.off() + + pdf(file.path(dirOut, "UpSet_Intersect.pdf"), width = 8.5, height = 4) + htIntersect <- plotUpSet(upSetData, Mode = "intersect") + draw(htIntersect) + dev.off() + } +} + + +# ──────────────────────────────────────────────────────────────────────────────────────────── +# PART THREE: Histogram distributions of: + # A: Targets per TF: # times each gene is targeted (number of TFs regulating it). + # B: TFs per Target: # times each TF is a regulator (number of genes it controls) + # C: Box-Plot of Stability of the top N low and high in degree genes + + #NOTE: Each Figure is saved in the same path as the network being evaluated +# ──────────────────────────────────────────────────────────────────────────────────────────── +# Generate both plots for each network +for (typeName in names(netDataList)) { + data <- netDataList[[typeName]] + basePath <- dirname(netFiles[[typeName]]) + + # Plot 1: # Targets per TF + tfTargetCounts <- table(data[[tfCol]]) + dfTF <- data.frame(TF = names(tfTargetCounts), targetCount = as.integer(tfTargetCounts)) + + p1 <- ggplot(dfTF, aes(x = targetCount)) + + geom_histogram(binwidth = 1, boundary = 0.5, fill = "#0072B2", color = "black", alpha = 0.8) + + # scale_x_continuous(breaks = seq(0, max(dfTF$targetCount), 1)) + + labs(title = paste("Distribution of # Targets per TF -", typeName), + x = "# Targets", y = "# TFs") + + theme_bw(base_family = "Helvetica") + + theme( + panel.grid.major.y = element_line(color = "grey80", linewidth = 0.3), + panel.grid.minor = element_blank(), + # axis.line = element_line(color = "black", linewidth = 0.4), + axis.text.x = element_text(size = 7), + axis.text.y = element_text(size = 7), + axis.title = element_text(size = 9), + plot.margin = margin(5, 5, 5, 5), + # panel.background = element_rect("black", fill = NA) + ) + + ggsave(file.path(basePath, paste0("TargetCountPerTF_", typeName, ".pdf")), + plot = p1, width = 3, height = 3) + + # Plot 2: # TFs per Target + targetTFCounts <- 
table(data[[targetCol]]) + dfTarget<- data.frame(Target = names(targetTFCounts), tfCount = as.integer(targetTFCounts)) + dfTargetSorted <- dfTarget[order(dfTarget$tfCount), ] + + p2 <- ggplot(dfTarget, aes(x = tfCount)) + + geom_histogram(binwidth = 1, boundary = 0.5, fill = "blue", color = "black", alpha = 0.7) + + # scale_x_continuous(breaks = seq(0, max(dfTarget$tfCount), 1)) + + labs(title = paste("Distribution of # TFs per Target -", typeName), + x = "# TFs", y = "# Targets") + + theme_bw(base_family = "Helvetica") + + theme( + panel.grid.major.y = element_line(color = "grey80", linewidth = 0.3), + panel.grid.minor = element_blank(), + # axis.line = element_line(color = "black", linewidth = 0.4), + axis.text.x = element_text(size = 7), + axis.text.y = element_text(size = 7), + axis.title = element_text(size = 9), + plot.margin = margin(5, 5, 5, 5), + # panel.background = element_rect("black", fill = NA) + ) + + ggsave(file.path(basePath, paste0("TFCountPerTarget_", typeName, ".pdf")), + plot = p2, width = 3, height = 3) + + # Select top N low and high TF targets + lowTargs <- dfTargetSorted$Target[1:nSelect] + highTargs <- tail(dfTargetSorted, nSelect)$Target + + stabilityDF <- data[data[[targetCol]] %in% c(lowTargs, highTargs), c(targetCol, stabCol)] + # Add Group column based on whether the target is in lowTargs or highTargs + stabilityDF$Group <- ifelse(stabilityDF[[targetCol]] %in% lowTargs, + "Low TFs per Target", + "High TFs per Target") + stabilityDF <- merge(stabilityDF, dfTargetSorted, by.x = targetCol, by.y = "Target") + stabilityDF$targetLabel <- paste0(stabilityDF[[targetCol]], "(", stabilityDF$tfCount, ")") + # Keep plotting order +# stabilityDF$targetLabel <- factor(stabilityDF$targetLabel, +# levels = unique(stabilityDF$targetLabel)) + # Order based on tfCount directly + stabilityDF$targetLabel <- factor( + stabilityDF$targetLabel, + levels = unique(stabilityDF$targetLabel[order(stabilityDF$tfCount)]) + ) + + # Plot: boxplot per target + p3 <- 
ggplot(stabilityDF, aes(x = targetLabel, y = Stability, fill = Group)) + + geom_boxplot(outlier.size = 0.5, alpha = 0.8) + + labs(title = paste("Per-Target Stability Distribution -", typeName), + x = "Number of TFs per Target (TF count)", y = "Stability") + + theme_minimal() + + theme_bw(base_family = "Helvetica") + + theme( + panel.grid.major.y = element_line(color = "grey80", linewidth = 0.3), + panel.grid.minor = element_blank(), + # axis.line = element_line(color = "black", linewidth = 0.4), + axis.text.x = element_text(size = 7, angle = 90, vjust = 0.5), + axis.text.y = element_text(size = 7), + axis.title = element_text(size = 9), + plot.margin = margin(5, 5, 5, 5), + legend.position = "top" + # panel.background = element_rect("black", fill = NA) + ) + + # Save + ggsave(file.path(basePath, paste0("Top", nSelect, "HighorLow_inDegreeGenes_Boxplot", typeName, ".pdf")), + plot = p3, width = 12, height = 5) + +} + diff --git a/evaluation/R/evaluateNetworks1.R b/evaluation/R/evaluateNetworks1.R new file mode 100755 index 0000000..f25e504 --- /dev/null +++ b/evaluation/R/evaluateNetworks1.R @@ -0,0 +1,708 @@ +rm(list = ls()) +suppressPackageStartupMessages({ + library(dplyr) + library(tidyr) + library(ggplot2) + library(ComplexHeatmap) + library(pheatmap) + library(circlize) + # library(VennDiagram) +# library(ggVennDiagram) + library(grid) + library(reshape2) + library(RColorBrewer) + library(purrr) + +}) + +source("/data/miraldiNB/Michael/Scripts/VennDiagram.R") +source("/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/RprogUtils/evaluateNetUtils.R") + +# dirOut <- "/data/miraldiNB/Michael/hCD4T_Katko/Inferelator/noMergedTF/Bulk/5kbTSS/newPipe/spikeIgnored/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63" +dirOut <- "/data/miraldiNB/Michael/mCD4T_Wayman/Figures/Fig4/networkComparison1" +dir.create(dirOut, recursive = T, showWarnings = F) + +# Combined/combined_max.tsv +# TFmRNA/edges_subset.txt +# netFiles <- list( +# "TFA" = 
"/data/miraldiNB/Michael/hCD4T_Katko/Inferelator/noMergedTF/SC/SCENIC/geneLambda0p5_220totSS_20tfsPerGene_subsamplePCT5_eRegulon/TFA/edges_subset.tsv", +# "TFmRNA" = "/data/miraldiNB/Michael/hCD4T_Katko/Inferelator/noMergedTF/SC/SCENIC/geneLambda0p5_220totSS_20tfsPerGene_subsamplePCT5_eRegulon/TFmRNA/edges_subset.tsv", +# "Combined" = "/data/miraldiNB/Michael/hCD4T_Katko/Inferelator/noMergedTF/SC/SCENIC/geneLambda0p5_220totSS_20tfsPerGene_subsamplePCT5_eRegulon/Combined/combined_max.tsv" +# ) + +# netFiles <- list( +# "TFA" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/TFA/edges_subset.tsv", +# "TFmRNA" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/TFmRNA/edges_subset.tsv", +# "Combined" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/Combined/combined_max.tsv" +# ) + +netFiles <- list( + "PB" = "/data/miraldiNB/Michael/projects/GRN/mCD4T_Wayman/Inferelator/noMergedTF/Bulk/ATAC/newPipe/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/Combined/combined_max.tsv", + "SC" = "/data/miraldiNB/Michael/projects/GRN/mCD4T_Wayman/Inferelator/noMergedTF/SC/ATAC/geneLambda0p5_220totSS_20tfsPerGene_subsamplePCT5/Combined/combined_max.tsv" + # "MC2" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/metaCells/MC2/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63_logNorm/Combined/combined_max.tsv", + # "SEA" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/noMergedTF/metaCells/SEACells/networkLambda0p5_100totSS_20tfsPerGene_subsamplePCT63/Combined/combined_max.tsv" +) + +# priorFile = "/data/miraldiNB/Michael/hCD4T_Katko/dataBank/Priors/MotifScan5kbTSS_b_sp.tsv" +priorFile = "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10_sp.tsv" +# k <- 
read.table("/data/miraldiNB/Michael/hCD4T_Katko/dataBank/Priors/SCENICp/SCENICp_RE2Glinks_FIMO_5Kb_derived_b.tsv") +# k$Target <- rownames(k) +# b <- melt(k, id.vars = "Target", variable.names = "TF", value.name = "Weigths") +# colnames(b)[2:3] <- c("TF", "Weights") +# b <- b[, c("TF", "Target", "Weights")] +# bb <- b[b$Weights != 0, ] +# write.table(bb, "/data/miraldiNB/Michael/hCD4T_Katko/dataBank/Priors/SCENICp/SCENICp_RE2Glinks_FIMO_5Kb_derived_sp.tsv", row.names = F, sep ="\t", quote = F) + +tfCol <- "TF" # specify TF column name here +targetCol <- "Gene" # specify Target column name here +priorTfCol <- "TF" # TF column name in prior file +priorTargetCol <- "Target" # Target column name in prior file +priorWgtCol <-"Weight" +nSelect <- 10 +stabCol <- "Stability" +rankCol <- "signedQuantile" +compareNets <- TRUE +tfList <- NULL +lineageTFs <-c("Bcl6","Maf","Batf","Tox","Ascl2", "Foxp3","Ikzf2","Rorc", "Rorc","Stat3", + "Tbx21","Stat4","Eomes","Prdm1","Tcf7","Klf2","Lef1") + +# lineageTFs <- c("TCF7","LEF1","ID3","KLF2","PRDM1","ZEB2","RUNX3","EOMES", +# "TBX21","STAT1","STAT4","HIF1A","RORC","RORA","STAT3","IRF4", +# "BATF","FOXP3","IKZF2","IKZF4","STAT5","FOXO1","GATA3","STAT6", +# "CIITA","RFX5","RFXAP","RFXANK","NLRC5") +N <- NULL +# file_k_clust_across <- "/Users/owop7y/Desktop/MiraldiLab/PROJECTS/GRN/Benchmark/mCD4T/Figures/Fig4/networkComparison/Kmeans_Jaccard_Edges_perTF.rds" +file_k_clust_across <- NULL + +# ======================================================================== +# ------- Read Input files +# ======================================================================== + +# ----- Read Prior +priorData <- read.table(priorFile, header = TRUE, sep = "\t", stringsAsFactors = FALSE) +head(priorData, 3) +priorData <- priorData[priorData[[priorWgtCol]] != 0, ] +priorPairs <- unique(paste(priorData[[priorTfCol]], priorData[[priorTargetCol]], sep = "_")) + +# --- Preload all network data into a list --- +netDataList <- list() +for (typeName in 
names(netFiles)) { + filePath <- netFiles[[typeName]] + netDataList[[typeName]] <- read.table(filePath, header = TRUE, sep = "\t", stringsAsFactors = FALSE) +} + +numNetworks <- length(netDataList) +networkNames <- names(netDataList) + +# ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── +# ---- PART ONE: Make a table of number of network ppts - uniqueTFs, uniqueTargets, Total Interactions, and % supported by prior +# ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── +# process each network +summaryList <- list() +for (typeName in networkNames) { + data <- netDataList[[typeName]] + + uniqueTf <- length(unique(data[[tfCol]])) + uniqueTarget <- length(unique(data[[targetCol]])) + + pairStrings <- unique(paste(data[[tfCol]], data[[targetCol]], sep = "_")) + totalInteractions <- length(pairStrings) + + supportedCount <- length(intersect(pairStrings, priorPairs)) + percentSupportedByPrior <- 100 * supportedCount / totalInteractions + + summaryList[[typeName]] <- data.frame(type = typeName, uniqueTf = uniqueTf, + uniqueTarget = uniqueTarget, totalInteractions = totalInteractions, + percentSupportedByPrior = percentSupportedByPrior) +} + +summaryTable <- do.call(rbind, summaryList) +print(summaryTable) + +write.table(as.data.frame(summaryTable), file.path(dirOut, "summaryStatistics.tsv"), col.names = NA, row.names = TRUE, quote = FALSE, sep = "\t") + +# ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── +# ---- PART TWO: Compare Networks (can handle more than 2) +# ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── + +# ──── Compare TF-Targets across methods +# Initialize a data frame to store results +tfTargetCounts <- data.frame(TF = 
unique(unlist(lapply(netDataList, function(x) unique(x$TF)))))
+
+# Count targets for each TF in each network
+for(netName in names(netDataList)) {
+  netData <- netDataList[[netName]]
+
+  # Count number of unique target genes per TF
+  targetCounts <- netData %>%
+    group_by(TF) %>%
+    summarise(targetCount = n())
+
+  colname <- netName
+  tfTargetCounts <- tfTargetCounts %>%
+    left_join(targetCounts, by = "TF") %>%
+    rename(!!colname := targetCount)
+}
+# Replace NA with 0 for TFs not present in a network
+tfTargetCounts[is.na(tfTargetCounts)] <- 0
+
+# NOTE(review): `topNpercent` was referenced below but never defined anywhere in
+# this script, so it aborted with "object 'topNpercent' not found". Default to
+# NULL so the "_all" suffix branch is taken; set a number (e.g. 10) to label
+# top-N% outputs instead.
+topNpercent <- NULL
+# Build file suffix
+fileSuffix <- if(!is.null(topNpercent)) {
+  paste0(rankCol, "_top", topNpercent, "pct")
+} else {
+  paste0(rankCol, "_all")
+}
+networkNames <- names(netDataList)
+numNetworks <- length(netDataList)
+
+edgeSets <- getEdgeSets(netDataList, tfCol, targetCol, rankCol)
+
+# -------- unique-edge counts ----------------------------------------- ##
+uniqCountsDf <- data.frame(
+  network = names(edgeSets),
+  numUniqueEdges = sapply(seq_along(edgeSets), function(k){
+    curr <- edgeSets[[k]]
+    others <- unique(unlist(edgeSets[-k]))
+    sum(!curr %in% others)
+  }),
+  row.names = NULL,
+  stringsAsFactors = FALSE
+  )
+
+write.table(uniqCountsDf,
+            file.path(dirOut, paste0("num_UniqueEdges_", fileSuffix, ".tsv")),
+            quote=F, col.names=NA, row.names=TRUE, sep="\t")
+
+# ------- Plot global Jaccard heatmap
+# pairs <- plotGlobalJaccard(edgeSets, fileOut = file.path(dirOut, paste0("Global_Jaccard_",fileSuffix, ".pdf")), fig_width = 2.5, fig_height = 2.5)
+# pairInt <- pairs[["pairInt"]]
+# NOTE(review): `pairs` was used below while its only assignment (the
+# plotGlobalJaccard call above) was commented out, so `pairs[["pairJac"]]`
+# resolved to base::pairs (a function) and errored. Compute the Jaccard /
+# intersection matrices with the same helper this script already uses for
+# the top-N% loop further below.
+pairs <- computeJaccardMatrix(edgeSets)
+
+plotSimilarityHeatmap(pairs[["pairJac"]], fileOut=file.path(dirOut, paste0("Global_Jaccard_",fileSuffix, ".pdf")), fontsize_number=7, fontsize=9, fig_width=2.5, fig_height=2.5)
+pairDf <- melt(pairs[["pairInt"]], varnames = c("network1", "network2"), value.name = "numIntersection", na.rm = TRUE)
+write.table(pairDf,file.path(dirOut, paste0("num_Pairwise_Intersection_", fileSuffix, ".tsv")), quote=F, col.names=NA, row.names=TRUE, 
sep="\t") + + +# --------- Compute Jaccard for multiple topN% ------------- +topNpercentList <- seq(10, 100, by = 10) +jaccardResults <- list() +for (topN in topNpercentList) { + currEdgeSets <- getEdgeSets(netDataList, tfCol, targetCol, rankCol, "topNpercent", topN) + # Compute pairwise Jaccard + fileSuffix <- paste0(rankCol, "_top", topN, "pct") + pairs <- computeJaccardMatrix(currEdgeSets) + pairJac <- pairs[["pairJac"]] + # Convert pairwise matrix to long format + pairJacDf <- as.data.frame(as.table(pairJac)) + colnames(pairJacDf) <- c("Network1", "Network2", "Jaccard") + pairJacDf$TopNpercent <- topN + jaccardResults[[as.character(topN)]] <- pairJacDf +} +# Combine all percentages +jaccardDf <- do.call(rbind, jaccardResults) +# Filter out self-comparisons +jaccardDf <- jaccardDf[jaccardDf$Network1 != jaccardDf$Network2, ] +jaccardDf<- jaccardDf %>% + dplyr::mutate( + Network1 = as.character(Network1), + Network2 = as.character(Network2), + pairID = ifelse(Network1 < Network2, + paste(Network1, Network2, sep = "-"), + paste(Network2, Network1, sep = "-")) + ) %>% + dplyr::group_by(TopNpercent, pairID) %>% + dplyr::slice(1) %>% + dplyr::ungroup() +write.table(jaccardDf,file.path(dirOut, "Jaccard_Sim.tsv"), quote=F, col.names=NA, row.names=TRUE, sep="\t") + +# Create color mapping based on Set1 palette +uniquePairs <- sort(unique(jaccardDf$pairID)) +colors <- c("#449B75" ,"#FFE528", "#999999", "#E41A1C", "#AC5782", "#C66764") +comparisonColor = setNames(colors, uniquePairs) + +# Plot line graph with ggplot2 +p_jac <- ggplot(jaccardDf, aes(x = TopNpercent, y = Jaccard,color = pairID,group = pairID)) + + geom_line(linewidth = 0.5) + + geom_point(size = 1) + + scale_color_manual(values = comparisonColor) + + scale_x_continuous(breaks = seq(10, 100, by = 10)) + + scale_y_continuous(limits = c(0,1)) + + labs( + x = "Top N% edges", + y = "Jaccard similarity", + color = "Network pair" + ) + + theme_bw(base_family = "Helvetica") + + theme( + axis.text.x = 
element_text(size = 7, color = "black"), + axis.text.y = element_text(size = 7, color = "black"), + axis.title = element_text(size = 9, color = "black"), + legend.title = element_blank(), + legend.text = element_text(size = 9, color = "black"), + legend.position = "right", + legend.box.just = "left", # centers the legend + panel.grid.major = element_line(color = "grey80", linewidth = 0.3), + panel.grid.minor = element_line(color = "grey90", linewidth = 0.1) + ) +ggsave(file.path(dirOut, "linePlot_JaccardSim.pdf"),plot = p_jac, width = 4, height = 3, dpi = 600) + +#plotEdgeSharing(edgeSets, fileSuffix, outDir) + +# --- 3. TF-centric analysis +# For each TF, how consistent are its predicted targets across networks? +# Biological networks are TF-driven, so checking target consistency per TF +tfMatrix <- computeTFJaccard(netDataList, tfList) +# MAke Heatmap +getPalette = colorRampPalette(brewer.pal(13, "Set1")) +comparisonColor = getPalette(length(colnames(tfMatrix))) +comparisonColor = setNames(comparisonColor, colnames(tfMatrix)) +colAnn <- HeatmapAnnotation( + `Comparison` = colnames(tfMatrix), + col = list('Comparison' = comparisonColor), + annotation_legend_param = list( + Comparison = list( + title_gp = gpar(fontsize = 8, fontface = "plain"), # title size + labels_gp = gpar(fontsize = 6) # category label size + )), + simple_anno_size = unit(3, "mm"), + show_annotation_name = FALSE + ) + + +tfMatrix[is.na(tfMatrix)] <- 0 +tfRobustness <- rowMeans(tfMatrix, na.rm = TRUE) # average Jaccard per TF + +k_center = 6 +if (is.null(file_k_clust_across)){ + k_clust_across <- kmeans(tfMatrix, centers=k_center,nstart=20, iter.max = 50) +}else{ + k_clust_across <- readRDS(file_k_clust_across) +} +split_by <- factor(k_clust_across$cluster, levels = c(1:6)) + +# Add row markers (labels with lines) +row_mark <- rowAnnotation( + mark = anno_mark( + at = which(rownames(tfMatrix) %in% lineageTFs), + labels = lineageTFs, + labels_gp = gpar(fontsize = 6, fontface = "plain"), + padding 
= unit(0.3, "mm"),
+    side = "left" # line starts from left of heatmap; NOTE(review): removed trailing comma — an empty trailing argument is an error in R calls
+  )
+)
+
+annot_cols <- c('1'='#fb940b','2'='#01cc00','3'='#2085ec','4'='#fe98bf','5'='#ad7a5b', '6'='turquoise4')
+df_clust <- data.frame(across=k_clust_across$cluster)
+df_clust$across <- factor(df_clust$across, levels=1:k_center, ordered=T)
+rowAnn_left <- HeatmapAnnotation(`Cluster`= df_clust$across,
+                                 col = list('Cluster'= annot_cols),
+                                 which = 'row',
+                                 simple_anno_size = unit(2, 'mm'),
+                                 show_legend = FALSE,
+                                 show_annotation_name = F)
+
+# robustCols <- colorRamp2(c(min(tfRobustness), max(tfRobustness)), c("lightgrey","blue"))
+# robustCols <- colorRamp2( seq(from = min(tfRobustness), to = max(tfRobustness), length.out = 200),inferno(200))
+robustCols <- colorRamp2(seq(min(tfRobustness), max(tfRobustness), length.out = 100), colorRampPalette(brewer.pal(9, "Blues"))(100))
+rowRobustness <- rowAnnotation(
+  Robustness = anno_simple(
+    tfRobustness,
+    col = robustCols, # NOTE(review): removed stray second comma (",,") — an empty argument is an error in R calls
+    border = TRUE
+  ),
+  show_annotation_name = F,
+  width = unit(1.5, "mm")
+)
+
+# Create a separate legend for robustness
+robustLegend <- Legend(
+  title = "TF Robustness",
+  col_fun = robustCols,
+  title_gp = gpar(fontsize = 8, fontface = "plain"), # unbold
+  labels_gp = gpar(fontsize = 6) # optional: label size
+)
+
+# dirOut <- "/data/MiraldiLab/team/Michael/mCD4T_WaymanFigures/Fig4/networkComparison/"
+# dir.create(dirOut, recursive = T, showWarnings = F)
+
+heatCols <- colorRampPalette(RColorBrewer::brewer.pal(9,"YlOrRd"))(100)
+pdf(file.path(dirOut, "Jaccard_Edges_perTF1.pdf"), width = 2.3, height = 4.5, compress = T)
+ht1 <- Heatmap(tfMatrix,
+               name='Jaccard Similarity',
+               show_row_names = F,
+               row_split=split_by,
+               row_gap = unit(0, "mm"),
+               column_gap = unit(0, "mm"),
+               border = TRUE,
+               row_title = NULL,
+               row_dend_reorder = F,
+               use_raster = F,
+               col = heatCols,
+               # clustering settings
+               cluster_rows = TRUE, # allow hierarchical clustering
+               cluster_row_slices = TRUE, # reorder within each k-means cluster
+ 
cluster_columns = FALSE, + cluster_column_slices = F, + # column_order = colnames(tfa_z_across), + show_row_dend = FALSE, # show dendrogram within slices + show_column_dend = FALSE, + show_column_names = FALSE, + column_names_side = 'top', + # column_names_rot = 45, + column_title = NULL, + column_split = colnames(tfMatrix), + top_annotation = colAnn, + left_annotation = rowAnn_left, + heatmap_legend_param = list( + direction = "horizontal", + title = "Jaccard", + title_position = "topcenter", + title_gp = gpar(fontsize = 8, fontface = "plain"), # legend title size + labels_gp = gpar(fontsize = 6), # numbers/tick labels size + legend_width = unit(2.5, "cm"), + legend_height = unit(0.5, "cm") + ) + ) +draw(row_mark+ht1+rowRobustness , heatmap_legend_side = "bottom") +dev.off() + +ht2 <- draw(ht1) +row_order_list <- row_order(ht2) +saveRDS(row_order_list, file.path(dirOut, "final_row_order.rds") ) +# Save Legend as PDF +pdf(file.path(dirOut, "TF_Robustness_Legend.pdf"), width = 2, height = 2) +draw(robustLegend) +dev.off() +saveRDS(k_clust_across, file.path(dirOut, "Kmeans_Jaccard_Edges_perTF.rds")) + + +# ---------- Check +clust_df <- df_clust +clust_df$TF <- rownames(clust_df) +colnames(clust_df)[1] <- "cluster" +tfTargetCounts <- tfTargetCounts %>% + left_join(clust_df, by = "TF") + +# Calculate average targets across all methods +tfTargetCounts$avgTargets <- rowMeans(tfTargetCounts[, names(netDataList)]) +tfTargetCounts$medianTargets <- apply(tfTargetCounts[, names(netDataList)],1,median, na.rm = T) +# # Summary statistics by cluster +# clusterSummary <- tfTargetCounts %>% +# group_by(cluster) %>% +# summarise( +# meanTargets = mean(avgTargets), +# medianTargets = median(avgTargets), +# sdTargets = sd(avgTargets), +# nTFs = n() +# ) + +write.table(tfTargetCounts, file.path(dirOut, "TF_avgTargetCounts_byRobustnessCluster.tsv"), quote=F, col.names=NA, row.names=TRUE, sep="\t") + +pAvg <- ggplot(tfTargetCounts, aes(x = cluster, y = avgTargets, fill = cluster)) + + 
geom_boxplot(outlier.color = "red", outlier.size = 1,outlier.shape = 19, outlier.alpha = 0.7) + + geom_jitter(width = 0.2, size = 0.1) + + labs(title = "", + y = "Average # of Targets", + x = "") + + scale_fill_manual(values = annot_cols) + + theme_minimal() + + theme_bw(base_family = "Helvetica") + + theme( + axis.text.y = element_text(size = 7, color = "black"), + axis.text.x = element_blank(), + axis.ticks.x = element_blank() , + axis.title = element_text(size = 9, color = "black"), + legend.position = "none", + panel.grid.major = element_line(color = "grey80", linewidth = 0.3), + panel.grid.minor = element_line(color = "grey90", linewidth = 0.1) + ) +ggsave(file.path(dirOut, "TF_avgTargetCounts_Distribution.pdf"), pAvg, width = 2.5, height = 2, dpi = 600) + + +pMedian <- ggplot(tfTargetCounts, aes(x = cluster, y = medianTargets, fill = cluster)) + + geom_boxplot(outlier.color = "red", outlier.size = 1,outlier.shape = 19, outlier.alpha = 0.7) + + geom_jitter(width = 0.2, size = 0.1) + + labs(title = "", + y = "Median # of Targets", + x = "") + + scale_fill_manual(values = annot_cols) + + theme_minimal() + + theme_bw(base_family = "Helvetica") + + theme( + axis.text.y = element_text(size = 7, color = "black"), + axis.text.x = element_blank(), + axis.ticks.x = element_blank() , + axis.title = element_text(size = 9, color = "black"), + legend.position = "none", + panel.grid.major = element_line(color = "grey80", linewidth = 0.3), + panel.grid.minor = element_line(color = "grey90", linewidth = 0.1) + ) +ggsave(file.path(dirOut, "TF_medianTargetCounts_Distribution.pdf"), pMedian, width = 2.5, height = 2, dpi = 600) + + +# tf_stats <- lapply(names(netDataList), function(nm){ +# df <- netDataList[[nm]] +# df %>% +# count(!!sym(tfCol)) %>% +# rename(TF = !!sym(tfCol), !!nm := n) +# }) %>% +# reduce(full_join, by = "TF") %>% +# mutate(cluster = k_clust_across$cluster[TF]) +# +# # Which cluster have higher mean degree and which has lower mean degree +# cluster_summary <- 
tf_stats %>% +# group_by(cluster) %>% +# summarise(across(starts_with("PB") | starts_with("SC") | starts_with("MC") | starts_with("SEA"), +# list(mean = ~mean(.x, na.rm=TRUE), +# sd = ~sd(.x, na.rm=TRUE))), +# n_TFs = n()) +# +# ## Only Lineage TFs +# tfMatrix_lineage <- tfMatrix[intersect(lineageTFs, rownames(tfMatrix)), ] +# k_clust_lineage<- kmeans(tfMatrix_lineage, centers=4,nstart=20, iter.max = 50) +# split_by_lineage <- factor(k_clust_lineage$cluster, levels = c(1:4)) +# +# pdf(file.path(dirOut, "Jaccard_Edges_perTF_lineage.pdf"), width = 2, height = 3, compress = T) +# ht <- Heatmap(tfMatrix_lineage, +# name = 'Jaccard Similarity', +# show_row_names = TRUE, +# row_names_gp = gpar(fontsize = 6), +# row_names_side = "left", # rownames on the left +# row_split = split_by_lineage, +# row_gap = unit(0, "mm"), +# column_gap = unit(0, "mm"), +# border = TRUE, +# row_title = NULL, +# row_dend_reorder = FALSE, +# use_raster = FALSE, +# col = heatCols, +# cluster_rows = TRUE, +# cluster_row_slices = TRUE, +# cluster_columns = FALSE, +# cluster_column_slices = FALSE, +# show_row_dend = FALSE, +# show_column_dend = FALSE, +# show_column_names = FALSE, +# column_names_side = 'top', +# column_title = NULL, +# column_split = colnames(tfMatrix_lineage), +# top_annotation = colAnn, +# heatmap_legend_param = list( +# direction = "horizontal", +# title = "Jaccard Similarity", +# title_position = "topcenter", +# title_gp = gpar(fontsize = 8, fontface = "plain"), # legend title size +# labels_gp = gpar(fontsize = 6), # numbers/tick labels size +# legend_width = unit(2.5, "cm"), +# legend_height = unit(0.5, "cm") +# ) +# ) +# draw(ht, heatmap_legend_side = "bottom") +# dev.off() +# saveRDS(k_clust_lineage, file.path(dirOut, "Kmeans_Jaccard_Edges_perTF_lineage.rds")) + +# --- 4. 
Aggregated network-pair Jaccard +fun = median +fun_name = "median" # label used in the output filename +aggMat <- computeAggregatedPairJaccard(tfMatrix, fun = fun) +plotAggregatedJaccard(aggMat,file.path(dirOut, paste0("AggregatedTF_Jaccard_", fun_name, ".pdf")), fig_width = 2.5, fig_height = 2.5) + +fun = mean +fun_name = "mean" # label used in the output filename +aggMat <- computeAggregatedPairJaccard(tfMatrix, fun = fun) +plotAggregatedJaccard(aggMat,file.path(dirOut, paste0("AggregatedTF_Jaccard_", fun_name, ".pdf")), fig_width = 2.5, fig_height = 2.5) + + +# ----- Are top regulators (high-degree TFs) stable across pseudobulk / single-cell / metacell networks. +hubSim <- computeHubOverlapHeatmap( + netDataList = netDataList, tfCol = tfCol, networkNames = networkNames, + dirOut = dirOut, topN = 50, + metric = "jaccard", # or "overlap" + fontsize = 9, fontsize_number = 7, fig_width = 2.5, fig_height = 2.5 +) + +hubSim <- computeHubOverlapHeatmap( + netDataList = netDataList, tfCol = tfCol, networkNames = networkNames, + dirOut = dirOut, topN = 50, + metric = "overlap", # or "jaccard" + fontsize = 9, fontsize_number = 7, fig_width = 2.5, fig_height = 2.5 +) + +# ──────────────────────────────────────────────────────────────────────────────────────────── +# PART THREE: Histogram distributions of: + # A: Targets per TF: # of targets per TF (number of genes each TF regulates). 
+ # B: TFs per Target: # times each TF is a regulator (number of genes it controls) + # C: Box-Plot of Stability of the top N low and high in degree genes + + #NOTE: Each Figure is saved in the same path as the network being evaluated +# ──────────────────────────────────────────────────────────────────────────────────────────── +# Generate both plots for each network +for (typeName in names(netDataList)) { + data <- netDataList[[typeName]] + basePath <- dirname(netFiles[[typeName]]) + + # Plot 1: # Targets per TF + # tfTargetCounts <- table(data[[tfCol]]) + # dfTF <- data.frame(TF = names(tfTargetCounts), targetCount = as.integer(tfTargetCounts)) + dfTF <- tfTargetCounts[, c("TF", tfCol), drop = F] + colnames(dfTF)[2] <- "targetCount" + + p1 <- ggplot(dfTF, aes(x = targetCount)) + + geom_histogram(binwidth = 1, boundary = 0.5, fill = "#0072B2", color = "black", alpha = 0.8) + + # scale_x_continuous(breaks = seq(0, max(dfTF$targetCount), 1)) + + labs(title = paste("Distribution of # Targets per TF -", typeName), + x = "# Targets", y = "# TFs") + + theme_bw(base_family = "Helvetica") + + theme( + panel.grid.major.y = element_line(color = "grey80", linewidth = 0.3), + panel.grid.minor = element_blank(), + # axis.line = element_line(color = "black", linewidth = 0.4), + axis.text.x = element_text(size = 7), + axis.text.y = element_text(size = 7), + axis.title = element_text(size = 9), + plot.margin = margin(5, 5, 5, 5), + # panel.background = element_rect("black", fill = NA) + ) + + ggsave(file.path(basePath, paste0("TargetCountPerTF_", typeName, ".pdf")), + plot = p1, width = 3, height = 3) + + # Plot 2: # TFs per Target + targetTFCounts <- table(data[[targetCol]]) + dfTarget<- data.frame(Target = names(targetTFCounts), tfCount = as.integer(targetTFCounts)) + dfTargetSorted <- dfTarget[order(dfTarget$tfCount), ] + + p2 <- ggplot(dfTarget, aes(x = tfCount)) + + geom_histogram(binwidth = 1, boundary = 0.5, fill = "blue", color = "black", alpha = 0.7) + + # 
scale_x_continuous(breaks = seq(0, max(dfTarget$tfCount), 1)) + + labs(title = paste("Distribution of # TFs per Target -", typeName), + x = "# TFs", y = "# Targets") + + theme_bw(base_family = "Helvetica") + + theme( + panel.grid.major.y = element_line(color = "grey80", linewidth = 0.3), + panel.grid.minor = element_blank(), + # axis.line = element_line(color = "black", linewidth = 0.4), + axis.text.x = element_text(size = 7), + axis.text.y = element_text(size = 7), + axis.title = element_text(size = 9), + plot.margin = margin(5, 5, 5, 5), + # panel.background = element_rect("black", fill = NA) + ) + + ggsave(file.path(basePath, paste0("TFCountPerTarget_", typeName, ".pdf")), + plot = p2, width = 3, height = 3) + + # Select top N low and high TF targets + lowTargs <- dfTargetSorted$Target[1:nSelect] + highTargs <- tail(dfTargetSorted, nSelect)$Target + + stabilityDF <- data[data[[targetCol]] %in% c(lowTargs, highTargs), c(targetCol, stabCol)] + # Add Group column based on whether the target is in lowTargs or highTargs + stabilityDF$Group <- ifelse(stabilityDF[[targetCol]] %in% lowTargs, + "Low TFs per Target", + "High TFs per Target") + stabilityDF <- merge(stabilityDF, dfTargetSorted, by.x = targetCol, by.y = "Target") + stabilityDF$targetLabel <- paste0(stabilityDF[[targetCol]], "(", stabilityDF$tfCount, ")") + # Keep plotting order +# stabilityDF$targetLabel <- factor(stabilityDF$targetLabel, +# levels = unique(stabilityDF$targetLabel)) + # Order based on tfCount directly + stabilityDF$targetLabel <- factor( + stabilityDF$targetLabel, + levels = unique(stabilityDF$targetLabel[order(stabilityDF$tfCount)]) + ) + + # Plot: boxplot per target + p3 <- ggplot(stabilityDF, aes(x = targetLabel, y = Stability, fill = Group)) + + geom_boxplot(outlier.size = 0.5, alpha = 0.8) + + labs(title = paste("Per-Target Stability Distribution -", typeName), + x = "Number of TFs per Target (TF count)", y = "Stability") + + theme_minimal() + + theme_bw(base_family = "Helvetica") + + 
theme( + panel.grid.major.y = element_line(color = "grey80", linewidth = 0.3), + panel.grid.minor = element_blank(), + # axis.line = element_line(color = "black", linewidth = 0.4), + axis.text.x = element_text(size = 7, angle = 90, vjust = 0.5), + axis.text.y = element_text(size = 7), + axis.title = element_text(size = 9), + plot.margin = margin(5, 5, 5, 5), + legend.position = "top" + # panel.background = element_rect("black", fill = NA) + ) + + # Save + ggsave(file.path(basePath, paste0("Top", nSelect, "HighorLow_inDegreeGenes_Boxplot", typeName, ".pdf")), + plot = p3, width = 12, height = 5) + +} + + + + +# + + + +# +# library(reshape2) +# library(ggplot2) +# +# # Convert to long format for ggplot +# df_long <- melt(tfa_by_celltype) +# colnames(df_long) <- c("TF", "CellType", "TFA") +# +# # Add a flag for lineage TFs +# df_long$Lineage <- mapply(function(tf, ct) { +# if(tf %in% lineage_TFs[[ct]]) "Lineage" else "Other" +# }, df_long$TF, df_long$CellType) +# +# # Boxplot or violin plot +# ggplot(df_long, aes(x = CellType, y = TFA, fill = Lineage)) + +# geom_violin(alpha = 0.6) + +# geom_boxplot(width=0.1, outlier.shape=NA) + +# scale_fill_manual(values = c("Lineage"="red", "Other"="grey")) + +# theme_minimal() + +# theme(axis.text.x = element_text(angle=45, hjust=1)) + +# labs(title="Lineage TF enrichment in TFA", y="Mean TFA") +# +# +# lineage_TFs <- list( +# Tfh10 = c("Bcl6","Maf","Batf"), +# Tfh_Int = c("Bcl6","Maf","Batf"), +# Tfh = c("Bcl6","Maf"), +# Tfr = c("Foxp3","Bcl6"), +# cTreg = c("Foxp3","Ikzf2"), +# eTreg = c("Foxp3","Ikzf2"), +# rTreg = c("Foxp3","Ikzf2"), +# Treg_Rorc = c("Foxp3","Rorc"), +# Th17 = c("Rorc","Batf","Stat3"), +# Th1 = c("Tbx21","Stat4"), +# CTL_Prdm1 = c("Tbx21","Eomes","Prdm1"), +# CTL_Bcl6 = c("Tbx21","Eomes","Prdm1"), +# TEM = c("Eomes","Tcf7","Klf2"), +# TCM = c("Eomes","Tcf7","Klf2") +# ) +# + +# lineage_tfs_list <- list( +# TCM = c("TCF7", "LEF1", "ID3", "KLF2"), +# TEM = c("PRDM1", "ZEB2", "RUNX3", "EOMES"), +# Th1 = 
c("TBX21", "STAT1", "STAT4", "RUNX3", "HIF1A"), +# Th17 = c("RORC", "RORA", "STAT3", "IRF4", "BATF"), +# Treg = c("FOXP3", "IKZF2", "IKZF4", "STAT5"), +# Naive = c("TCF7", "LEF1", "KLF2", "FOXO1"), +# Th2 = c("GATA3", "STAT5", "STAT6", "IRF4"), +# CTL = c("EOMES", "TBX21", "RUNX3", "ZEB2", "PRDM1"), +# MHCII = c("CIITA", "RFX5", "RFXAP", "RFXANK", "NLRC5") +# ) + +# +# +# # Distribution per F1. +# # Also show F1 score based on heatmap +# ggplot(pr_df, aes(x = network, y = F1)) + +# geom_boxplot(outlier.size = 0.8) + +# geom_jitter(width = 0.15, alpha = 0.6, size = 1) + +# theme_minimal() + +# theme(axis.text.x = element_text(angle=45, hjust=1)) + +# ylab("F1 score") + xlab("Representation") diff --git a/evaluation/R/histogramConfidences.R b/evaluation/R/histogramConfidences.R new file mode 100755 index 0000000..c233cf0 --- /dev/null +++ b/evaluation/R/histogramConfidences.R @@ -0,0 +1,194 @@ +# Plot weights/confidences +rm(list=ls()) +library(ggplot2) + +# Custom function to format y-axis labels +custom_labels <- function(x) { + # Keep 0 as is, scale other values by 1000 and add "K" + labels <- ifelse(x == 0, "0", paste0(scales::number(x / 1000, accuracy = 0.1), "K")) + return(labels) +} + +histogramConfidencesDir <- function(currNetDirs, breaks) { + # Ensure breaks length matches currNetDirs length + if (length(currNetDirs) != length(breaks)) { + stop("Length of breaks must match length of currNetDirs.") + } + + for (ix in seq_along(currNetDirs)) { + currNetDir <- currNetDirs[ix] + subfolders <- list.dirs(currNetDir, recursive = FALSE) + target_folders <- subfolders[basename(subfolders) %in% c("TFA", "TFmRNA")] + + # Load and plot individually + for (subfolder in target_folders) { + # subfolder = target_folders[jx] + file_path <- file.path(subfolder, "edges_subset.txt") + if (file.exists(file_path)) { + data <- read.table(file_path, header = TRUE, sep = "\t") + + # Create individual histogram + conf <- data.frame(confidence = floor(data$Stability) / 
max(floor(data$Stability))) + p <- ggplot((conf), aes(x = confidence)) + + geom_histogram(bins = breaks[ix], fill = "blue", color = "black", alpha = 0.9) + + labs(title = basename(subfolder), x = "Confidence", y = "Frequency") + + theme_bw() + + theme( + axis.title = element_text(size = 16, color = "black"), + axis.text = element_text(size = 14, color = "black"), + plot.title = element_text(size = 20,, color = "black", hjust = 0.5) + ) + scale_y_continuous(labels = custom_labels)#+ scale_y_continuous(labels = label_number(scale = 1e-3, suffix = "K")) + + # Save individual plot + hist_file <- file.path(subfolder, paste0("confidence_distribution_", breaks[ix], ".png")) + ggsave(hist_file, plot = p, width = 6, height = 5, dpi = 600) + } else { + message("File not found: ", file_path) + } + } + + } + +} + + +histogramConfidencesData <- function(currNetFiles, breaks) { + # Ensure breaks length matches currNetFiles length + if (length(currNetFiles) > length(breaks)) { + stop("Length of breaks must match or be greater length of currNetFiles.") + } + + # Loop over the directories in currNetFiles + for (ix in seq_along(currNetFiles)) { + currNetFile <- currNetFiles[ix] # Full path to edges_subset.txt + + # Check if the file exists + if (file.exists(currNetFile)) { + data <- read.table(currNetFile, header = TRUE, sep = "\t") + + # Create individual histogram + conf <- data.frame(confidence = floor(data$Stability) / max(floor(data$Stability))) + p <- ggplot(conf, aes(x = confidence)) + + geom_histogram(bins = breaks[ix], fill = "blue", color = "black", alpha = 0.9) + + labs(title = basename(dirname(currNetFile)), x = "Confidence", y = "Frequency") + + theme_bw() + + theme( + axis.title = element_text(size = 16, color = "black"), + axis.text = element_text(size = 14, color = "black"), + plot.title = element_text(size = 20, color = "black", hjust = 0.5) + ) + scale_y_continuous(labels = custom_labels) # Custom y-axis labels + + # Save individual plot in the same directory + 
hist_file <- file.path(dirname(currNetFile), paste0("confidence_distribution_", breaks[ix], ".png")) + ggsave(hist_file, plot = p, width = 6, height = 5, dpi = 600) + } else { + message("File not found: ", currNetFile) + } + } +} + +currNetFile <- "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/1KCells/lambda0p5_220totSS_20tfsPerGene_subsamplePCT20/TFmRNA/edges_subset.txt" + +conf <- data.frame(confidence = floor(data$Stability)) +breaks <- max(conf) +p <- ggplot(conf, aes(x = confidence)) + + geom_histogram(bins = breaks, fill = "blue", color = "black", alpha = 0.9) + + labs(title = basename(dirname(currNetFile)), x = "Confidence", y = "Frequency") + + theme_bw() + + theme( + axis.title = element_text(size = 16, color = "black"), + axis.text = element_text(size = 14, color = "black"), + plot.title = element_text(size = 20, color = "black", hjust = 0.5) + ) + + hist_file <- file.path(dirname(currNetFile), paste0("testConf_", breaks, ".png")) + ggsave(hist_file, plot = p, width = 6, height = 5, dpi = 600) + +# USAGE + + +currNetDirs <- c( + # ---Pseudobulk Inferelator + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda0p25_80totSS_20tfsPerGene_subsamplePCT63", + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63", + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/Bulk/lambda1p0_80totSS_20tfsPerGene_subsamplePCT63", + + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/Bulk/lambda0p25_80totSS_20tfsPerGene_subsamplePCT63", + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63", + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/Bulk/lambda1p0_80totSS_20tfsPerGene_subsamplePCT63", + + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/Bulk/lambda0p25_80totSS_20tfsPerGene_subsamplePCT63", + 
"/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/Bulk/lambda0p5_80totSS_20tfsPerGene_subsamplePCT63", + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/Bulk/lambda1p0_80totSS_20tfsPerGene_subsamplePCT63", + + # --- single cell inferelator + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/lambda0p25_220totSS_20tfsPerGene_subsamplePCT10", + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/lambda0p5_220totSS_20tfsPerGene_subsamplePCT10", + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/lambda1p0_220totSS_20tfsPerGene_subsamplePCT10", + + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/SC/lambda0p25_220totSS_20tfsPerGene_subsamplePCT10", + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/SC/lambda0p5_220totSS_20tfsPerGene_subsamplePCT10", + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_ChIPprior/SC/lambda1p0_220totSS_20tfsPerGene_subsamplePCT10", + + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/SC/lambda0p25_220totSS_20tfsPerGene_subsamplePCT10", + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/SC/lambda0p5_220totSS_20tfsPerGene_subsamplePCT10", + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATAC_KOprior/SC/lambda1p0_220totSS_20tfsPerGene_subsamplePCT10", + + # --- single cell downsampled + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/1KCells/lambda0p5_220totSS_20tfsPerGene_subsamplePCT27", + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/10KCells/lambda0p5_220totSS_20tfsPerGene_subsamplePCT27", + "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/30KCells/lambda0p5_220totSS_20tfsPerGene_subsamplePCT13" +) + + +# Extract subsample fractions +# subsample_fracs <- sapply(currNetDirs, function(x) sub(".*subsampleFrac_([0-9p]+).*", "\\1", x)) +# subsample_fracs <- gsub("p", ".", subsample_fracs) # Convert 'p' to '.' 
for numeric comparison +# unique_fracs <- unique(subsample_fracs) + +breaks = c(rep(150, 9), rep(500, 12)) + + + +netFiles = c( + "1K" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/1KCells/lambda0p5_220totSS_20tfsPerGene_subsamplePCT27/TFA/edges_subset.txt", + "10K" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/10KCells/lambda0p5_220totSS_20tfsPerGene_subsamplePCT27/TFA/edges_subset.txt", + "30K" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/30KCells/lambda0p5_220totSS_20tfsPerGene_subsamplePCT13/TFA/edges_subset.txt", + "77K" = "/data/miraldiNB/Michael/mCD4T_Wayman/Inferelator/ATACprior/SC/lambda0p5_220totSS_20tfsPerGene_subsamplePCT10/TFA/edges_subset.txt" + ) + +histogramConfidencesStacked <- function(currNetFiles, breaks, dirOut, saveName) { + + allData <- data.frame(confidence = numeric(), network = character()) + + for (ix in seq_along(currNetFiles)) { + currNetFile <- currNetFiles[ix] # Full path to edges_subset.txt + currNetName <- names(currNetFile) + if (file.exists(currNetFile)) { + data <- read.table(currNetFile, header = TRUE, sep = "\t") + # Compute absolute value of signedQuantile and create a temporary data frame. 
+ conf <- data.frame(confidence = floor(data$Stability) / max(floor(data$Stability))) + tempDf <- data.frame(confidence = conf$confidence, + network = currNetName) + allData <- rbind(allData, tempDf) + } else { + message("File not found: ", currNetFile) + } + } + + p <- ggplot(allData, aes(x = confidence, color = network)) + + # geom_freqpoly(bins = breaks[1], size = 1) + + geom_histogram(bins = breaks[1], alpha = 0.6, position = "identity", color = "black") + + labs(x = "Confidence", y = "Frequency") + + theme_bw() + + theme( + axis.title = element_text(size = 16, color = "black"), + axis.text = element_text(size = 14, color = "black") + ) + histFile <- file.path(dirOut, paste0(saveName, ".pdf")) + ggsave(histFile, plot = p, width = 6, height = 5, dpi = 600) + +} + + +histogramConfidencesStacked(netFiles, breaks = 300, dirOut = "/data/miraldiNB/Michael/mCD4T_Wayman/Figures", saveName = "confidencesTest") \ No newline at end of file diff --git a/evaluation/R/saveNormCountsArrowFIle.R b/evaluation/R/saveNormCountsArrowFIle.R new file mode 100755 index 0000000..ddf647a --- /dev/null +++ b/evaluation/R/saveNormCountsArrowFIle.R @@ -0,0 +1,35 @@ +rm(list=ls()) +options(stringsAsFactors=FALSE) +set.seed(42) +suppressPackageStartupMessages({ + library(Signac) + library(Seurat) + library(arrow) +}) + +# make sure you load or have the latest version of curl installed. + +# make sure you load or have the latest version of curl installed. + +rds_path <- "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/annotation_scrna_final/obj_Tfh10_RNA_annotated.rds" +obj <- readRDS(rds_path) + +# Process File for scGRN using Inferelator. 
+obj <- NormalizeData(obj) +norm_counts <- as.matrix(GetAssayData(obj, layer='data')) +# write_parquet(as.data.frame(norm_counts), "/data/miraldiNB/Michael/GRN_Benchmark/Data/Tfh10_scRNA_logNorm_Counts.parquet") + +norm_counts <- as.data.frame(norm_counts) +norm_counts$Genes <- rownames(norm_counts) +norm_counts <- norm_counts[, c("Genes", setdiff(colnames(norm_counts), "Genes"))] # Reorder the columns to make 'Genes' the first column + +# Save as an arrow file +write_feather(norm_counts, "/data/miraldiNB/Michael/GRN_Benchmark/Data/Tfh10_scRNA_logNorm_Counts.arrow") + +# write_feather(norm_counts, "/data/miraldiNB/Michael/hCD4T_Katko/dataBank/scNormCounts.arrow") + +# feather_file <- read_feather("/data/miraldiNB/Katko/Projects/Barski_CD4_Multiome/Outs/Pseudobulk/RNA2/SC_counts.feather") +# write_ipc_file(feather_file, "/data/miraldiNB/Katko/Projects/Barski_CD4_Multiome/Outs/Pseudobulk/RNA2/SC_counts.arrow") + + +rds_path <- "/data/miraldiNB/Michael/hCD4T_Katko/dataBank/ObjFiltered.rds" \ No newline at end of file diff --git a/examples/interactive_pipeline.jl b/examples/interactive_pipeline.jl new file mode 100644 index 0000000..03fcfe0 --- /dev/null +++ b/examples/interactive_pipeline.jl @@ -0,0 +1,172 @@ +# ============================================================================= +# interactive_pipeline.jl — Step-by-step GRN inference using the public API +# +# What: +# Runs the full 6-step mLASSO-StARS pipeline interactively. Each step calls +# one high-level API function. Intermediate structs remain in the REPL between +# steps so you can inspect, plot, or debug before continuing. 
+# +# Required inputs: +# geneExprFile — gene expression matrix (.txt tab-delimited or .arrow) +# targFile — target gene list (.txt, one gene per line) +# regFile — potential regulator (TF) list (.txt, one TF per line) +# priorFile — prior network matrix (.tsv, sparse TF × gene) +# priorFilePenalties — prior(s) used to set LASSO penalties (same or different) +# +# Expected outputs (written under outputDir/<networkBaseName>/): +# TFA/edges.tsv — GRN inferred with TFA predictors +# TFmRNA/edges.tsv — GRN inferred with TF mRNA predictors +# Combined/combined_*.tsv — consensus network (max/mean/min aggregation) +# Combined/TFA/ — refined TFA network re-estimated from consensus prior +# +# Usage: +# Step-by-step in a Julia REPL (recommended for interactive analysis) +# julia examples/interactive_pipeline.jl +# +# Compare with: +# interactive_pipeline_dev.jl — same steps, module-qualified internal calls +# run_pipeline.jl — same pipeline wrapped in a callable function +# +# Installation: +# pkg> dev /path/to/InferelatorJL # local development +# pkg> add "https://github.com/org/InferelatorJL.jl" # published release +# Tip: load Revise before this file to pick up source edits without restarting: +# using Revise; using InferelatorJL +# ============================================================================= +using Revise +using InferelatorJL + +# ============================================================================= +# Configuration — edit these paths and parameters for your dataset +# ============================================================================= + +outputDir = "/data/miraldiNB/Michael/projects/GRN/mCD4T_Wayman/Inferelator/test3" +tfaOptions = ["", "TFmRNA"] # "" → TFA mode, "TFmRNA" → mRNA mode +totSS = 80 +bstarsTotSS = 5 +subsampleFrac = 0.68 +minLambda = 0.01 +maxLambda = 0.5 +totLambdasBstars = 20 +totLambdas = 40 +targetInstability = 0.05 +meanEdgesPerGene = 20 +correlationWeight = 1 +minTargets = 3 +edgeSS = 0 +lambdaBias = [0.5] +instabilityLevel = 
"Network" # "Network" or "Gene" +useMeanEdgesPerGeneMode = true +combineOpt = "max" # "max", "mean", or "min" +zScoreTFA = true # z-score targets before TFA estimation +zScoreLASSO = true # z-score targets before LASSO regression + +geneExprFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/pseudobulk/pseudobulk_scrna/CellType/Age/Factor1/min0.25M/counts_Tfh10_AgeCellType_pseudobulk_scrna_vst_batch_downsample_0.25M.txt" +targFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/target_genes/gene_targ_Tfh10_SigPct5Log2FC0p58FDR5.txt" +regFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/pot_regs/TF_Tfh10_SigPct5Log2FC0p58FDR5_final.txt" + +priorFile = "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv" +priorFilePenalties = ["/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv"] +tfaGeneFile = "" # optional: restrict TFA estimation to a gene subset + +# --- Build output directory name (encodes key run parameters) +subsamplePct = subsampleFrac * 100 +subsampleStr = isinteger(subsamplePct) ? string(Int(subsamplePct)) : replace(string(subsamplePct), "." => "p") +lambdaStr = join(replace.(string.(lambdaBias), "." => "p"), "_") +networkBaseName = lowercase(instabilityLevel) * "Lambda" * lambdaStr * "_" * string(totSS) * "totSS_" * + string(meanEdgesPerGene) * "tfsPerGene_" * "subsamplePCT" * subsampleStr +dirOut = joinpath(outputDir, networkBaseName) +mkpath(dirOut) + +@info "Configuration" outputDir=dirOut geneExprFile priorFile lambdaBias subsampleFrac + +# ============================================================================= +# STEP 1 — Load and filter expression data +# ============================================================================= +# Loads expression matrix, filters to target genes and potential regulators. 
+# Inspect: fieldnames(GeneExpressionData), size(data.expressionMat) + +data = loadData(geneExprFile, targFile, regFile; + tfaGeneFile = tfaGeneFile, + epsilon = 0.01) + +# ============================================================================= +# STEP 2 + 3 — Merge degenerate TFs, process prior, estimate TFA +# ============================================================================= +# Merges TFs with identical binding profiles, builds the prior matrix, +# and estimates TF activity (TFA) via least-squares. +# Inspect: size(priorData.medTfas), priorData.tfNames + +priorData, mergedTFs = loadPrior(data, priorFile; minTargets = minTargets) + +estimateTFA(priorData, data; + edgeSS = edgeSS, + zScoreTFA = zScoreTFA, + outputDir = dirOut) + +# ============================================================================= +# STEP 4 — Build GRN for each predictor mode +# ============================================================================= +# Runs mLASSO-StARS for TFA predictors and TF mRNA predictors separately. +# Outputs instability curves and ranked edge lists to TFA/ and TFmRNA/. 
+ +for (tfaMode, modeLabel) in [(true, "TFA"), (false, "TFmRNA")] + instabilitiesDir = joinpath(dirOut, modeLabel) + mkpath(instabilitiesDir) + + @info "Building network" mode=modeLabel + + buildNetwork(data, priorData; + tfaMode = tfaMode, + priorFilePenalties = priorFilePenalties, + lambdaBias = lambdaBias, + totSS = totSS, + bstarsTotSS = bstarsTotSS, + subsampleFrac = subsampleFrac, + minLambda = minLambda, + maxLambda = maxLambda, + totLambdasBstars = totLambdasBstars, + totLambdas = totLambdas, + targetInstability = targetInstability, + meanEdgesPerGene = meanEdgesPerGene, + correlationWeight = correlationWeight, + instabilityLevel = instabilityLevel, + useMeanEdgesPerGeneMode = useMeanEdgesPerGeneMode, + zScoreLASSO = zScoreLASSO, + outputDir = instabilitiesDir) +end + +# ============================================================================= +# STEP 5 — Aggregate TFA + mRNA networks into a consensus network +# ============================================================================= +# Combines the two edge lists using max/mean/min stability per (TF, gene) pair. +# Outputs combined_.tsv and combined__sp.tsv to Combined/. + +combinedNetDir = joinpath(dirOut, "Combined") +aggregateNetworks( + [joinpath(dirOut, "TFA", "edges.tsv"), + joinpath(dirOut, "TFmRNA", "edges.tsv")]; + method = Symbol(combineOpt), + meanEdgesPerGene = meanEdgesPerGene, + useMeanEdgesPerGene = useMeanEdgesPerGeneMode, + outputDir = combinedNetDir) + +# ============================================================================= +# STEP 6 — Re-estimate TFA using the consensus network as a refined prior +# ============================================================================= +# Uses the combined network as a new prior to re-estimate TF activity, then +# re-runs mLASSO-StARS. Outputs go to Combined/TFA/. 
+ +netsCombinedSparse = joinpath(combinedNetDir, "combined_" * combineOpt * "_sp.tsv") +refineTFA(netsCombinedSparse, data, mergedTFs; + tfaGeneFile = tfaGeneFile, + edgeSS = edgeSS, + minTargets = minTargets, + zScoreTFA = zScoreTFA, + exprFile = geneExprFile, + targFile = targFile, + regFile = regFile, + outputDir = combinedNetDir) + +@info "Pipeline complete" outputDir=dirOut diff --git a/examples/interactive_pipeline_dev.jl b/examples/interactive_pipeline_dev.jl new file mode 100644 index 0000000..f3e6fc7 --- /dev/null +++ b/examples/interactive_pipeline_dev.jl @@ -0,0 +1,199 @@ +# ============================================================================= +# interactive_pipeline_dev.jl — Step-by-step GRN inference via internal functions +# +# What: +# Identical pipeline to interactive_pipeline.jl but calls internal functions +# directly using the InferelatorJL. module prefix. Use this when you need +# finer control over individual steps (e.g., inspecting intermediate matrices, +# swapping in a custom subfunction, or debugging a specific stage). +# All internal functions remain accessible via module-qualified calls even +# though they are not exported from the package. 
+# +# Required inputs: +# geneExprFile — gene expression matrix (.txt tab-delimited or .arrow) +# targFile — target gene list (.txt, one gene per line) +# regFile — potential regulator (TF) list (.txt, one TF per line) +# priorFile — prior network matrix (.tsv, sparse TF × gene) +# priorFilePenalties — prior(s) used to set LASSO penalties (same or different) +# +# Expected outputs (written under outputDir//): +# TFA/edges.tsv — GRN inferred with TFA predictors +# TFmRNA/edges.tsv — GRN inferred with TF mRNA predictors +# Combined/combined_*.tsv — consensus network (max/mean/min aggregation) +# Combined/TFA/ — refined TFA network re-estimated from consensus prior +# +# Usage: +# Step-by-step in a Julia REPL (recommended for inspecting intermediate state) +# julia examples/interactive_pipeline_dev.jl +# +# Compare with: +# interactive_pipeline.jl — same steps via high-level public API +# run_pipeline_dev.jl — same internal calls wrapped in a callable function +# +# Installation: +# pkg> dev /path/to/InferelatorJL # local development +# pkg> add "https://github.com/org/InferelatorJL.jl" # published release +# Tip: load Revise before this file to pick up source edits without restarting: +# using Revise; using InferelatorJL +# ============================================================================= +using Revise +using InferelatorJL + +# ============================================================================= +# Configuration — edit these paths and parameters for your dataset +# ============================================================================= + +outputDir = "/data/miraldiNB/Michael/projects/GRN/mCD4T_Wayman/Inferelator/test" +tfaOptions = ["", "TFmRNA"] # "" → TFA mode, "TFmRNA" → mRNA mode +totSS = 80 +bstarsTotSS = 5 +subsampleFrac = 0.68 +minLambda = 0.01 +maxLambda = 0.5 +totLambdasBstars = 20 +totLambdas = 40 +targetInstability = 0.05 +meanEdgesPerGene = 20 +correlationWeight = 1 +minTargets = 3 +edgeSS = 0 +lambdaBias = [0.5] 
+instabilityLevel = "Network" # "Network" or "Gene" +useMeanEdgesPerGeneMode = true +combineOpt = "max" # "max", "mean", or "min" +zScoreTFA = true # z-score targets before TFA estimation +zScoreLASSO = true # z-score targets before LASSO regression + +geneExprFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/pseudobulk/pseudobulk_scrna/CellType/Age/Factor1/min0.25M/counts_Tfh10_AgeCellType_pseudobulk_scrna_vst_batch_downsample_0.25M.txt" +targFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/target_genes/gene_targ_Tfh10_SigPct5Log2FC0p58FDR5.txt" +regFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/pot_regs/TF_Tfh10_SigPct5Log2FC0p58FDR5_final.txt" + +priorFile = "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv" +priorFilePenalties = ["/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv"] +tfaGeneFile = "" # optional: restrict TFA estimation to a gene subset + +# --- Build output directory name (encodes key run parameters) +subsamplePct = subsampleFrac * 100 +subsampleStr = isinteger(subsamplePct) ? string(Int(subsamplePct)) : replace(string(subsamplePct), "." => "p") +lambdaStr = join(replace.(string.(lambdaBias), "." => "p"), "_") +networkBaseName = lowercase(instabilityLevel) * "Lambda" * lambdaStr * "_" * string(totSS) * "totSS_" * + string(meanEdgesPerGene) * "tfsPerGene_" * "subsamplePCT" * subsampleStr +dirOut = joinpath(outputDir, networkBaseName) +mkpath(dirOut) + +@info "Configuration" outputDir=dirOut geneExprFile priorFile lambdaBias subsampleFrac + +# ============================================================================= +# STEP 1 — Load and filter expression data +# ============================================================================= +# Loads expression matrix, filters to target genes and potential regulators. 
+# Inspect: fieldnames(GeneExpressionData), size(data.expressionMat), data.geneNames + +data = GeneExpressionData() +InferelatorJL.loadExpressionData!(data, geneExprFile) +InferelatorJL.loadAndFilterTargetGenes!(data, targFile; epsilon = 0.01) +InferelatorJL.loadPotentialRegulators!(data, regFile) +InferelatorJL.processTFAGenes!(data, tfaGeneFile; outputDir = dirOut) + +# ============================================================================= +# STEP 2 — Merge degenerate TFs +# ============================================================================= +# Identifies TFs with identical binding profiles in the prior and merges them. +# Inspect: mergedTFsData.mergedTFs, mergedTFsData.tfNames + +mergedTFsData = mergedTFsResult() +InferelatorJL.mergeDegenerateTFs(mergedTFsData, priorFile; fileFormat = 2) + +# ============================================================================= +# STEP 3 — Process prior and estimate TFA +# ============================================================================= +# Builds the filtered prior matrix and estimates TF activity via least-squares. +# Inspect: tfaData.priorMat, tfaData.medTfas, size(tfaData.medTfas) + +tfaData = PriorTFAData() +InferelatorJL.processPriorFile!(tfaData, data, priorFile; + mergedTFsData = mergedTFsData, + minTargets = minTargets) +InferelatorJL.calculateTFA!(tfaData, data; + edgeSS = edgeSS, + zTarget = zScoreTFA, + outputDir = dirOut) + +# ============================================================================= +# STEP 4 — Build GRN for each predictor mode +# ============================================================================= +# Runs subsampling, warm-start lambda selection, instability estimation, +# lambda selection, and edge ranking for each predictor mode. + +for tfaOpt in tfaOptions + instabilitiesDir = tfaOpt == "" ? joinpath(dirOut, "TFA") : joinpath(dirOut, "TFmRNA") + mkpath(instabilitiesDir) + + @info "Building network" tfaOpt=(isempty(tfaOpt) ? 
"TFA" : tfaOpt) + + grnData = GrnData() + InferelatorJL.preparePredictorMat!(grnData, data, tfaData; tfaOpt = tfaOpt) + InferelatorJL.preparePenaltyMatrix!(data, grnData; + priorFilePenalties = priorFilePenalties, + lambdaBias = lambdaBias, + tfaOpt = tfaOpt) + InferelatorJL.constructSubsamples(data, grnData; totSS = bstarsTotSS, subsampleFrac = subsampleFrac) + InferelatorJL.bstarsWarmStart(data, tfaData, grnData; + minLambda = minLambda, + maxLambda = maxLambda, + totLambdasBstars = totLambdasBstars, + targetInstability = targetInstability, + zTarget = zScoreLASSO) + InferelatorJL.constructSubsamples(data, grnData; totSS = totSS, subsampleFrac = subsampleFrac) + InferelatorJL.bstartsEstimateInstability(grnData; + totLambdas = totLambdas, + instabilityLevel = instabilityLevel, + zTarget = zScoreLASSO, + outputDir = instabilitiesDir) + + buildGrn = BuildGrn() + InferelatorJL.chooseLambda!(grnData, buildGrn; + instabilityLevel = instabilityLevel, + targetInstability = targetInstability) + InferelatorJL.rankEdges!(data, tfaData, grnData, buildGrn; + useMeanEdgesPerGeneMode = useMeanEdgesPerGeneMode, + meanEdgesPerGene = meanEdgesPerGene, + correlationWeight = correlationWeight, + outputDir = instabilitiesDir) + writeNetworkTable!(buildGrn; outputDir = instabilitiesDir) +end + +# ============================================================================= +# STEP 5 — Aggregate TFA + mRNA networks into a consensus network +# ============================================================================= +# Combines edge lists using max/mean/min stability per (TF, gene) pair. 
+ +combinedNetDir = joinpath(dirOut, "Combined") +nets2combine = [ + joinpath(dirOut, "TFA", "edges.tsv"), + joinpath(dirOut, "TFmRNA", "edges.tsv") +] +InferelatorJL.aggregateNetworks(nets2combine; + method = Symbol(combineOpt), + meanEdgesPerGene = meanEdgesPerGene, + useMeanEdgesPerGene = useMeanEdgesPerGeneMode, + outputDir = combinedNetDir) + +# ============================================================================= +# STEP 6 — Re-estimate TFA using the consensus network as a refined prior +# ============================================================================= +# Uses combined network as new prior, re-estimates TFA, re-runs mLASSO-StARS. + +netsCombinedSparse = joinpath(combinedNetDir, "combined_" * combineOpt * "_sp.tsv") +InferelatorJL.refineTFA(data, mergedTFsData; + priorFile = netsCombinedSparse, + tfaGeneFile = tfaGeneFile, + edgeSS = edgeSS, + minTargets = minTargets, + zTarget = zScoreTFA, + geneExprFile = geneExprFile, + targFile = targFile, + regFile = regFile, + outputDir = combinedNetDir) + +@info "Pipeline complete" outputDir=dirOut diff --git a/examples/plotPR.jl b/examples/plotPR.jl new file mode 100644 index 0000000..5d8909a --- /dev/null +++ b/examples/plotPR.jl @@ -0,0 +1,221 @@ +# ============================================================================= +# plotPR.jl — Evaluate GRN predictions against gold standards and plot PR curves +# +# What: +# Evaluates one or more inferred GRNs against gold-standard interaction sets +# by computing precision-recall (PR) and ROC metrics, then generating +# publication-quality PR curve plots. Supports multiple networks and +# multiple gold standards in a single run. Optionally generates per-TF +# PR curves and AUPR bar plots. 
+#
+# Required inputs:
+# outNetFiles — inferred GRN file(s) as legend label → file path dict
+# gsParam — gold-standard file(s) as name → file path dict
+# prTargGeneFile — target gene list used to restrict evaluation universe
+# (set to "" to use all genes in the network)
+# gsRegsFile — regulator list to restrict evaluation to shared TFs
+# (set to "" to use all regulators)
+#
+# Expected outputs (written relative to each network file's directory):
+# PR_noPotRegs/<gsName>/ — PR data files (if gsRegsFile = "")
+# PR_withPotRegs/<gsName>/ — PR data files (if gsRegsFile is set)
+# dirOutPlot/<figBaseName>_<gsName>_*.png — PR curve plots and optional AUPR bar plots
+#
+# Usage:
+# julia examples/plotPR.jl
+# or configure the USER CONFIG section and run step-by-step in the REPL
+#
+# Installation:
+# pkg> dev /path/to/InferelatorJL # local development
+# pkg> add "https://github.com/org/InferelatorJL.jl" # published release
+# Tip: load Revise before this file to pick up source edits without restarting:
+# using Revise; using InferelatorJL
+# =============================================================================

+using InferelatorJL
+import InferelatorJL: computePR, plotPRCurves, plotAUPR, loadPRData

+using OrderedCollections

+# ══════════════════════════════════════════════════════════════════════════════
+# USER CONFIG — edit this section
+# ══════════════════════════════════════════════════════════════════════════════

+# Output directory for plots
+dirOutPlot = "/data/miraldiNB/Michael/projects/goldStandard/Human/fdr0p01_vs_IDR"

+# Base name for saved figures (set to "" to use gold-standard name only)
+figBaseName = "ALL_5KB_sharedTF_Target"

+# Network files to compare: legend label => file path
+outNetFiles = OrderedDict(
+    "FDR0p01" => "/data/miraldiNB/Michael/projects/goldStandard/Human/GS_FDR0p01/20260325/GS_5KB_TSS/sumRegion_reducePool/All/context_mean/global_max/GS_All_peakScore10_top50_sharedTF_Gene_withIDR.tsv",
+    "IDR" => 
"/data/miraldiNB/Michael/projects/goldStandard/Human/GS_IDR/20260321/GS_5KB_TSS/sumRegion_reducePool/All/context_mean/global_max/GS_All_peakScore10_top50_sharedTF_Gene_withFDR0p01.tsv" +) + +# Gold-standard files: name => file path +gsParam = OrderedDict( + "KO_GS" => "/data/miraldiNB/Michael/projects/GRN/hCD4T_Katko/dataBank/GS/KO_GS_50_Michael_autosomal.tsv", +) + +# Evaluation inputs +prTargGeneFile = "/data/miraldiNB/Michael/projects/GRN/hCD4T_Katko/dataBank/potTargRegs/Targs/all_targs_autosomal.txt" +gsRegsFile = "/data/miraldiNB/Katko/Projects/Barski_CD4_Multiome/Outs/Prior/SubsetPriors/all_TFs.txt" +rankColTrn = 3 # column in GRN file corresponding to interaction ranks/confidences +breakTies = true +auprLimit = 0.1 + +# Plot parameters +lineTypes = [] # e.g. ["-", "--", "-."] — one per dataset; [] uses defaults +lineWidths = [] # per dataset; [] uses defaults +lineColors = [] # per dataset; [] uses defaults +xLimitRecall = 0.1 +yStepSize = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1] # one per gold standard +yScaleType = "linear" +yZoomPR = [[0.4, 0.9], [0.2, 0.9], [0.3, 0.9], [], [], [], [], [], [], [], [0.7, 0.9], [], [0.7, 0.9]] # one per gold standard, or [] +heightRatios = [0.5, 3.0] # height ratios for broken y-axis panels +isInside = false # legend inside the plot +plotAUPRflag = false # set to true to also generate AUPR bar plots +combinePlot = true # generate a combined PR curve per network/GS pair +doPerTF = true # compute per-TF PR metrics +tfList = [] # list of TF names for per-TF curves; [] skips Section 3 + +# ══════════════════════════════════════════════════════════════════════════════ +# EXECUTION — no edits needed below this line +# ══════════════════════════════════════════════════════════════════════════════ + +mkpath(dirOutPlot) + +# --- Helper: resolve per-GS plot parameters --- +function getPlotParams(i, gsName; figBaseName, yZoomPR, yStepSize) + saveNamePR = isempty(figBaseName) ? 
"$(gsName)" : "$(figBaseName)_$(gsName)" + currentYzoomPR = (length(yZoomPR) >= i && !isempty(yZoomPR[i])) ? yZoomPR[i] : Float64[] + currentYstepSize = (length(yStepSize) >= i && !isempty(yStepSize[i])) ? yStepSize[i] : nothing + return saveNamePR, currentYzoomPR, currentYstepSize +end + +# ── 1. Calculate PR/ROC metrics ─────────────────────────────────────────────── +@info "---- 1. Calculating Performance Metrics for the Networks -----" +prFilesByGS = OrderedDict{String, OrderedDict{String, Any}}() + +for (legendLabel, outNetFile) in outNetFiles + @info "Processing network" network=legendLabel file=outNetFile + filepath = dirname(outNetFile) + + for (gsName, gsFile) in gsParam + dirOut = isempty(gsRegsFile) ? joinpath(filepath, "PR_noPotRegs", gsName) : + joinpath(filepath, "PR_withPotRegs", gsName) + mkpath(dirOut) + @info "Using GS" gs=gsName saveDir=dirOut + + res = computePR(gsFile, outNetFile; + gsRegsFile = gsRegsFile, + targGeneFile = prTargGeneFile, + rankColTrn = rankColTrn, + breakTies = breakTies, + partialAUPRlimit = auprLimit, + doPerTF = doPerTF, + saveDir = dirOut) + + if !haskey(prFilesByGS, gsName) + prFilesByGS[gsName] = OrderedDict{String, Any}() + end + prFilesByGS[gsName][legendLabel] = haskey(res, :savedFile) ? res[:savedFile] : res + end +end + +# ── 2. Global PR curves ─────────────────────────────────────────────────────── +@info "---- 2. 
Generating Global PR Curves ----" +if combinePlot + for (i, (gsName, listFilePR)) in enumerate(prFilesByGS) + @info "Plotting PR curves" gs=gsName + + saveNamePR, currentYzoomPR, currentYstepSize = getPlotParams(i, gsName; + figBaseName = figBaseName, + yZoomPR = yZoomPR, + yStepSize = yStepSize) + + plotPRCurves(listFilePR, dirOutPlot, saveNamePR; + xLimitRecall = xLimitRecall, + yZoomPR = currentYzoomPR, + yStepSize = currentYstepSize, + yScale = yScaleType, + isInside = isInside, + lineColors = lineColors, + lineTypes = lineTypes, + lineWidths = lineWidths, + heightRatios = heightRatios, + mode = :global) + + if plotAUPRflag + singleGS = OrderedDict(gsName => listFilePR) + saveNameAUPR = isempty(figBaseName) ? "$(gsName)" : "$(figBaseName)_$(gsName)" + + for (figSize, saveLegend) in [((5, 4), true), ((1.5, 1.5), false)] + plotAUPR(singleGS, dirOutPlot; + saveName = saveNameAUPR, + metricType = "partial", + figSize = figSize, + axisTitleSize = 9, + tickLabelSize = 7, + legendFontSize = 9, + tickRotation = 45, + plotType = "bar", + saveLegend = saveLegend) + end + end + @info "Plots completed" gs=gsName + end +end + +# ── 3. Per-TF PR curves ─────────────────────────────────────────────────────── +if !isempty(tfList) + @info "----- 3. 
Generating Per-TF PR Curves -----" + for (i, (gsName, resultsDict)) in enumerate(prFilesByGS) + @info "Plotting per-TF PR curves" gs=gsName + + saveNamePR, currentYzoomPR, currentYstepSize = getPlotParams(i, gsName; + figBaseName = figBaseName, + yZoomPR = yZoomPR, + yStepSize = yStepSize) + saveNamePR = "perTF_$(saveNamePR)" + tfListPR = OrderedDict() + + resCache = Dict{String, Any}() + for (runName, source) in resultsDict + resCache[runName] = loadPRData(source; mode = :perTF) + end + + for (runName, res) in resCache + res === nothing && continue + tfIndex = Dict(tf => j for (j, tf) in enumerate(res[:gsRegs])) + for tf in tfList + idx = get(tfIndex, tf, nothing) + idx === nothing && continue + + label = + length(tfList) == 1 && length(resultsDict) > 1 ? runName : + length(resultsDict) == 1 && length(tfList) > 1 ? tf : + "$runName - $tf" + + tfListPR[label] = Dict( + :precisions => res[:precisions][idx], + :recalls => res[:recalls][idx], + :randPR => res[:randPR][idx] + ) + end + end + + plotPRCurves(tfListPR, dirOutPlot, saveNamePR; + xLimitRecall = xLimitRecall, + yZoomPR = currentYzoomPR, + yStepSize = currentYstepSize, + yScale = yScaleType, + isInside = isInside, + lineColors = lineColors, + lineTypes = lineTypes, + lineWidths = lineWidths) + end +end + +@info "Completed — plots generated for all gold standards" diff --git a/examples/run_pipeline.jl b/examples/run_pipeline.jl new file mode 100644 index 0000000..0ce0342 --- /dev/null +++ b/examples/run_pipeline.jl @@ -0,0 +1,160 @@ +# ============================================================================= +# run_pipeline.jl — Batch GRN inference wrapped in a callable function (public API) +# +# What: +# Wraps the full 6-step mLASSO-StARS pipeline in runInferelator(), which +# accepts all parameters as keyword arguments with sensible defaults. +# Use this for batch execution, scripted cluster jobs, or when running +# multiple parameter sweeps programmatically. 
+# All steps call high-level public API functions (no internal functions exposed).
+#
+# Required inputs:
+# geneExprFile — gene expression matrix (.txt tab-delimited or .arrow)
+# targFile — target gene list (.txt, one gene per line)
+# regFile — potential regulator (TF) list (.txt, one TF per line)
+# priorFile — prior network matrix (.tsv, sparse TF × gene)
+# priorFilePenalties — prior(s) used to set LASSO penalties (same or different)
+# outputDir — root directory for all outputs
+#
+# Expected outputs (written under outputDir/<networkBaseName>/):
+# TFA/edges.tsv — GRN inferred with TFA predictors
+# TFmRNA/edges.tsv — GRN inferred with TF mRNA predictors
+# Combined/combined_*.tsv — consensus network (max/mean/min aggregation)
+# Combined/TFA/ — refined TFA network re-estimated from consensus prior
+#
+# Usage:
+# julia examples/run_pipeline.jl
+# or call runInferelator() from another script after: using InferelatorJL
+#
+# Compare with:
+# interactive_pipeline.jl — same steps run interactively (not wrapped)
+# run_pipeline_dev.jl — same function using module-qualified internal calls
+#
+# Installation:
+# pkg> dev /path/to/InferelatorJL # local development
+# pkg> add "https://github.com/org/InferelatorJL.jl" # published release
+# Tip: load Revise before this file to pick up source edits without restarting:
+# using Revise; using InferelatorJL
+# =============================================================================

+using InferelatorJL

+# =============================================================================
+# Pipeline function
+# =============================================================================

+function runInferelator(;
+    geneExprFile::String,
+    targFile::String,
+    regFile::String,
+    priorFile::String,
+    priorFilePenalties::Vector{String},
+    tfaGeneFile::String = "",
+    outputDir::String,
+    totSS::Int = 80,
+    bstarsTotSS::Int = 5,
+    subsampleFrac::Float64 = 0.68,
+    minLambda::Float64 = 0.01,
+    maxLambda::Float64 = 0.5,
+    
totLambdasBstars::Int = 20, + totLambdas::Int = 40, + targetInstability::Float64 = 0.05, + meanEdgesPerGene::Int = 20, + correlationWeight::Int = 1, + minTargets::Int = 3, + edgeSS::Int = 0, + lambdaBias::Vector{Float64} = [0.5], + instabilityLevel::String = "Network", # "Network" or "Gene" + useMeanEdgesPerGeneMode::Bool = true, + combineOpt::String = "max", # "max", "mean", or "min" + zScoreTFA::Bool = true, + zScoreLASSO::Bool = true +) + # Build output directory name (encodes key run parameters) + subsamplePct = subsampleFrac * 100 + subsampleStr = isinteger(subsamplePct) ? string(Int(subsamplePct)) : replace(string(subsamplePct), "." => "p") + lambdaStr = join(replace.(string.(lambdaBias), "." => "p"), "_") + networkBaseName = lowercase(instabilityLevel) * "Lambda" * lambdaStr * "_" * string(totSS) * "totSS_" * + string(meanEdgesPerGene) * "tfsPerGene_" * "subsamplePCT" * subsampleStr + dirOut = joinpath(outputDir, networkBaseName) + mkpath(dirOut) + + @info "Starting pipeline" outputDir=dirOut geneExprFile priorFile lambdaBias subsampleFrac + + # Step 1 — Load and filter expression data + data = loadData(geneExprFile, targFile, regFile; + tfaGeneFile = tfaGeneFile, + epsilon = 0.01) + + # Steps 2 + 3 — Merge degenerate TFs, process prior, estimate TFA + priorData, mergedTFs = loadPrior(data, priorFile; minTargets = minTargets) + + estimateTFA(priorData, data; + edgeSS = edgeSS, + zScoreTFA = zScoreTFA, + outputDir = dirOut) + + # Step 4 — Build GRN for each predictor mode + for (tfaMode, modeLabel) in [(true, "TFA"), (false, "TFmRNA")] + instabilitiesDir = joinpath(dirOut, modeLabel) + mkpath(instabilitiesDir) + + @info "Building network" mode=modeLabel + + buildNetwork(data, priorData; + tfaMode = tfaMode, + priorFilePenalties = priorFilePenalties, + lambdaBias = lambdaBias, + totSS = totSS, + bstarsTotSS = bstarsTotSS, + subsampleFrac = subsampleFrac, + minLambda = minLambda, + maxLambda = maxLambda, + totLambdasBstars = totLambdasBstars, + totLambdas = 
totLambdas, + targetInstability = targetInstability, + meanEdgesPerGene = meanEdgesPerGene, + correlationWeight = correlationWeight, + instabilityLevel = instabilityLevel, + useMeanEdgesPerGeneMode = useMeanEdgesPerGeneMode, + zScoreLASSO = zScoreLASSO, + outputDir = instabilitiesDir) + end + + # Step 5 — Aggregate TFA + mRNA networks into a consensus network + combinedNetDir = joinpath(dirOut, "Combined") + aggregateNetworks( + [joinpath(dirOut, "TFA", "edges.tsv"), + joinpath(dirOut, "TFmRNA", "edges.tsv")]; + method = Symbol(combineOpt), + meanEdgesPerGene = meanEdgesPerGene, + useMeanEdgesPerGene = useMeanEdgesPerGeneMode, + outputDir = combinedNetDir) + + # Step 6 — Re-estimate TFA using the consensus network as a refined prior + netsCombinedSparse = joinpath(combinedNetDir, "combined_" * combineOpt * "_sp.tsv") + refineTFA(netsCombinedSparse, data, mergedTFs; + tfaGeneFile = tfaGeneFile, + edgeSS = edgeSS, + minTargets = minTargets, + zScoreTFA = zScoreTFA, + exprFile = geneExprFile, + targFile = targFile, + regFile = regFile, + outputDir = combinedNetDir) + + @info "Pipeline complete" outputDir=dirOut +end + +# ============================================================================= +# Run — replace paths with your own data +# ============================================================================= + +runInferelator( + geneExprFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/pseudobulk/pseudobulk_scrna/CellType/Age/Factor1/min0.25M/counts_Tfh10_AgeCellType_pseudobulk_scrna_vst_batch_NoState.txt", + targFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/target_genes/gene_targ_Tfh10_SigPct5Log2FC0p58FDR5.txt", + regFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/pot_regs/TF_Tfh10_SigPct5Log2FC0p58FDR5_final.txt", + priorFile = "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv", + priorFilePenalties = 
["/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv"], + outputDir = "/data/miraldiNB/Michael/projects/GRN/mCD4T_Wayman/Inferelator/test" +) diff --git a/examples/run_pipeline_dev.jl b/examples/run_pipeline_dev.jl new file mode 100644 index 0000000..af51089 --- /dev/null +++ b/examples/run_pipeline_dev.jl @@ -0,0 +1,184 @@ +# ============================================================================= +# run_pipeline_dev.jl — Batch GRN inference via module-qualified internal calls +# +# What: +# Identical pipeline to run_pipeline.jl but calls internal functions directly +# using the InferelatorJL. module prefix. Use this when you need finer control +# over individual pipeline steps while still running in batch/function mode. +# All internal functions remain accessible via module-qualified calls even +# though they are not exported from the package. +# +# Required inputs: +# geneExprFile — gene expression matrix (.txt tab-delimited or .arrow) +# targFile — target gene list (.txt, one gene per line) +# regFile — potential regulator (TF) list (.txt, one TF per line) +# priorFile — prior network matrix (.tsv, sparse TF × gene) +# priorFilePenalties — prior(s) used to set LASSO penalties (same or different) +# outputDir — root directory for all outputs +# +# Expected outputs (written under outputDir//): +# TFA/edges.tsv — GRN inferred with TFA predictors +# TFmRNA/edges.tsv — GRN inferred with TF mRNA predictors +# Combined/combined_*.tsv — consensus network (max/mean/min aggregation) +# Combined/TFA/ — refined TFA network re-estimated from consensus prior +# +# Usage: +# julia examples/run_pipeline_dev.jl +# or call runInferelator() from another script after: using InferelatorJL +# +# Compare with: +# interactive_pipeline_dev.jl — same internal calls run interactively +# run_pipeline.jl — same function using public API +# +# Installation: +# pkg> dev /path/to/InferelatorJL # local development +# pkg> add 
"https://github.com/org/InferelatorJL.jl" # published release +# Tip: load Revise before this file to pick up source edits without restarting: +# using Revise; using InferelatorJL +# ============================================================================= + +using InferelatorJL + +# ============================================================================= +# Pipeline function +# ============================================================================= + +function runInferelator(; + geneExprFile::String, + targFile::String, + regFile::String, + priorFile::String, + priorFilePenalties::Vector{String}, + tfaGeneFile::String = "", + outputDir::String, + tfaOptions::Vector{String} = ["", "TFmRNA"], # "" → TFA, "TFmRNA" → mRNA + totSS::Int = 80, + bstarsTotSS::Int = 5, + subsampleFrac::Float64 = 0.68, + minLambda::Float64 = 0.01, + maxLambda::Float64 = 0.5, + totLambdasBstars::Int = 20, + totLambdas::Int = 40, + targetInstability::Float64 = 0.05, + meanEdgesPerGene::Int = 20, + correlationWeight::Int = 1, + minTargets::Int = 3, + edgeSS::Int = 0, + lambdaBias::Vector{Float64} = [0.5], + instabilityLevel::String = "Network", # "Network" or "Gene" + useMeanEdgesPerGeneMode::Bool = true, + combineOpt::String = "max", # "max", "mean", or "min" + zScoreTFA::Bool = true, + zScoreLASSO::Bool = true +) + # Build output directory name (encodes key run parameters) + subsamplePct = subsampleFrac * 100 + subsampleStr = isinteger(subsamplePct) ? string(Int(subsamplePct)) : replace(string(subsamplePct), "." => "p") + lambdaStr = join(replace.(string.(lambdaBias), "." 
=> "p"), "_") + networkBaseName = lowercase(instabilityLevel) * "Lambda" * lambdaStr * "_" * string(totSS) * "totSS_" * + string(meanEdgesPerGene) * "tfsPerGene_" * "subsamplePCT" * subsampleStr + dirOut = joinpath(outputDir, networkBaseName) + mkpath(dirOut) + + @info "Starting pipeline" outputDir=dirOut geneExprFile priorFile lambdaBias subsampleFrac + + # Step 1 — Load and filter expression data + data = GeneExpressionData() + InferelatorJL.loadExpressionData!(data, geneExprFile) + InferelatorJL.loadAndFilterTargetGenes!(data, targFile; epsilon = 0.01) + InferelatorJL.loadPotentialRegulators!(data, regFile) + InferelatorJL.processTFAGenes!(data, tfaGeneFile; outputDir = dirOut) + + # Step 2 — Merge degenerate TFs + mergedTFsData = mergedTFsResult() + InferelatorJL.mergeDegenerateTFs(mergedTFsData, priorFile; fileFormat = 2) + + # Step 3 — Process prior and estimate TFA + tfaData = PriorTFAData() + InferelatorJL.processPriorFile!(tfaData, data, priorFile; + mergedTFsData = mergedTFsData, + minTargets = minTargets) + InferelatorJL.calculateTFA!(tfaData, data; + edgeSS = edgeSS, + zTarget = zScoreTFA, + outputDir = dirOut) + + # Step 4 — Build GRN for each predictor mode + for tfaOpt in tfaOptions + instabilitiesDir = tfaOpt == "" ? joinpath(dirOut, "TFA") : joinpath(dirOut, "TFmRNA") + mkpath(instabilitiesDir) + + @info "Building network" tfaOpt=(isempty(tfaOpt) ? 
"TFA" : tfaOpt) + + grnData = GrnData() + InferelatorJL.preparePredictorMat!(grnData, data, tfaData; tfaOpt = tfaOpt) + InferelatorJL.preparePenaltyMatrix!(data, grnData; + priorFilePenalties = priorFilePenalties, + lambdaBias = lambdaBias, + tfaOpt = tfaOpt) + InferelatorJL.constructSubsamples(data, grnData; totSS = bstarsTotSS, subsampleFrac = subsampleFrac) + InferelatorJL.bstarsWarmStart(data, tfaData, grnData; + minLambda = minLambda, + maxLambda = maxLambda, + totLambdasBstars = totLambdasBstars, + targetInstability = targetInstability, + zTarget = zScoreLASSO) + InferelatorJL.constructSubsamples(data, grnData; totSS = totSS, subsampleFrac = subsampleFrac) + InferelatorJL.bstartsEstimateInstability(grnData; + totLambdas = totLambdas, + instabilityLevel = instabilityLevel, + zTarget = zScoreLASSO, + outputDir = instabilitiesDir) + + buildGrn = BuildGrn() + InferelatorJL.chooseLambda!(grnData, buildGrn; + instabilityLevel = instabilityLevel, + targetInstability = targetInstability) + InferelatorJL.rankEdges!(data, tfaData, grnData, buildGrn; + useMeanEdgesPerGeneMode = useMeanEdgesPerGeneMode, + meanEdgesPerGene = meanEdgesPerGene, + correlationWeight = correlationWeight, + outputDir = instabilitiesDir) + writeNetworkTable!(buildGrn; outputDir = instabilitiesDir) + end + + # Step 5 — Aggregate TFA + mRNA networks into a consensus network + combinedNetDir = joinpath(dirOut, "Combined") + nets2combine = [ + joinpath(dirOut, "TFA", "edges.tsv"), + joinpath(dirOut, "TFmRNA", "edges.tsv") + ] + InferelatorJL.aggregateNetworks(nets2combine; + method = Symbol(combineOpt), + meanEdgesPerGene = meanEdgesPerGene, + useMeanEdgesPerGene = useMeanEdgesPerGeneMode, + outputDir = combinedNetDir) + + # Step 6 — Re-estimate TFA using the consensus network as a refined prior + netsCombinedSparse = joinpath(combinedNetDir, "combined_" * combineOpt * "_sp.tsv") + InferelatorJL.refineTFA(data, mergedTFsData; + priorFile = netsCombinedSparse, + tfaGeneFile = tfaGeneFile, + edgeSS = 
edgeSS, + minTargets = minTargets, + zTarget = zScoreTFA, + geneExprFile = geneExprFile, + targFile = targFile, + regFile = regFile, + outputDir = combinedNetDir) + + @info "Pipeline complete" outputDir=dirOut +end + +# ============================================================================= +# Run — replace paths with your own data +# ============================================================================= + +runInferelator( + geneExprFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/pseudobulk/pseudobulk_scrna/CellType/Age/Factor1/min0.25M/counts_Tfh10_AgeCellType_pseudobulk_scrna_vst_batch_NoState.txt", + targFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/target_genes/gene_targ_Tfh10_SigPct5Log2FC0p58FDR5.txt", + regFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/pot_regs/TF_Tfh10_SigPct5Log2FC0p58FDR5_final.txt", + priorFile = "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv", + priorFilePenalties = ["/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv"], + outputDir = "/data/miraldiNB/Michael/projects/GRN/mCD4T_Wayman/Inferelator/test" +) diff --git a/examples/utilityExamples.jl b/examples/utilityExamples.jl new file mode 100644 index 0000000..c88413d --- /dev/null +++ b/examples/utilityExamples.jl @@ -0,0 +1,245 @@ +# ============================================================================= +# utilityExamples.jl — Self-contained examples for InferelatorJL utility functions +# +# What: +# Demonstrates all exported utility functions using small synthetic datasets. +# No real input files are required — all data is generated inline. +# Sections: +# 1. DataUtils — reshape, normalize, binarize, merge prior matrices +# 2. PartialCorrelation — precision-matrix and regression-based methods +# 3. NetworkIO — write edge tables and save data structs +# 4. 
aggregateNetworks — combine multiple GRN edge files into one network +# +# Required inputs: None (all data is synthetic) +# +# Expected outputs: Printed results to the REPL / stdout. +# Sections 3 and 4 also write temporary files to tempdir(). +# +# Usage: +# julia examples/utilityExamples.jl +# or paste individual sections into a Julia REPL +# +# Installation: +# pkg> dev /path/to/InferelatorJL # local development +# pkg> add "https://github.com/org/InferelatorJL.jl" # published release +# Tip: load Revise before this file to pick up source edits without restarting: +# using Revise; using InferelatorJL +# ============================================================================= + +using InferelatorJL +using DataFrames, CSV, Random, Statistics, LinearAlgebra + + +# ============================================================================= +# 1. DataUtils +# ============================================================================= + +println("\n===== 1. DataUtils =====\n") + +# ── convertToLong / convertToWide ──────────────────────────────────────────── +# Wide prior matrix: genes (rows) × TFs (columns) +widePrior = DataFrame( + Gene = ["Gata3", "Foxp3", "Tbx21", "Rorc"], + TF_A = [1.0, 0.0, 0.5, 0.0], + TF_B = [0.0, 1.0, 0.0, 1.0], + TF_C = [0.8, 0.0, 0.0, 0.6] +) + +# Convert to long (TF, Gene, Weight) +longPrior = convertToLong(widePrior) +println("Wide → Long (first 5 rows):") +println(first(longPrior, 5)) + +# Round-trip back to wide +widePrior2 = convertToWide(longPrior; indices = (2, 1, 3)) +println("\nLong → Wide (restored):") +println(widePrior2) + + +# ── frobeniusNormalize ──────────────────────────────────────────────────────── +# Normalize each column of the prior matrix so each column has unit L2 norm. +# Typical use: normalize prior before feeding into penalty matrix construction. 
+priorNorm = frobeniusNormalize(widePrior, :column) +println("\nColumn-normalized prior (each TF column has ||·||₂ = 1):") +println(priorNorm) + +# Verify norms are 1 +details, nTrue, nFalse = check_column_norms(priorNorm; atol = 1e-8) +println("Columns with unit norm: $nTrue / $(nTrue + nFalse)") + +# Row normalization (normalize each gene's regulatory weight vector) +priorNormRow = frobeniusNormalize(widePrior, :row) +println("\nRow-normalized prior:") +println(priorNormRow) + + +# ── binarizeNumeric! ────────────────────────────────────────────────────────── +# Replace all non-zero values with 1 (convert continuous prior → binary prior). +priorBinary = deepcopy(widePrior) +binarizeNumeric!(priorBinary) +println("\nBinarized prior:") +println(priorBinary) + + +# ── mergeDFs ────────────────────────────────────────────────────────────────── +# Merge two prior DataFrames from different assays (ATAC + ChIP) by summing. +# Useful when combining evidence from multiple prior sources. +prior_atac = DataFrame(Gene = ["Gata3","Foxp3","Tbx21"], TF_A = [1.0,0.0,0.5], TF_B = [0.0,1.0,0.0]) +prior_chip = DataFrame(Gene = ["Foxp3","Tbx21","Rorc"], TF_A = [0.0,0.3,0.0], TF_C = [1.0,0.0,1.0]) + +mergedPrior = mergeDFs([prior_atac, prior_chip], :Gene, "sum") +println("\nMerged prior (ATAC + ChIP, sum):") +println(mergedPrior) + +mergedPriorAvg = mergeDFs([prior_atac, prior_chip], :Gene, "avg") +println("\nMerged prior (ATAC + ChIP, average):") +println(mergedPriorAvg) + + +# ── completeDF ──────────────────────────────────────────────────────────────── +# Align a DataFrame to a fixed set of row IDs and column names (fills missing → 0). 
+allGenes = ["Gata3", "Foxp3", "Tbx21", "Rorc", "Il2ra"] +allTFs = [:TF_A, :TF_B, :TF_C, :TF_D] + +completedPrior = completeDF(prior_atac, :Gene, allGenes, allTFs) +println("\ncompleteDF — prior_atac aligned to full gene/TF universe:") +println(completedPrior) + + +# ── writeTSVWithEmptyFirstHeader ────────────────────────────────────────────── +# Write a DataFrame to TSV with the first (row-label) column header left blank. +# This is the sparse prior format expected by processPriorFile!. +tmpDir = tempdir() +sparseFile = joinpath(tmpDir, "prior_example_sp.tsv") +writeTSVWithEmptyFirstHeader(priorBinary, sparseFile; delim = '\t') +println("\nSparse prior written to: $sparseFile") +println(read(sparseFile, String)[1:min(200, filesize(sparseFile))]) + + +# ============================================================================= +# 2. PartialCorrelation +# ============================================================================= + +println("\n===== 2. PartialCorrelation =====\n") + +# Synthetic expression matrix: 50 cells × 6 genes +Random.seed!(42) +X = randn(50, 6) +# Add a latent signal so genes 1 and 3 are partially correlated +latent = randn(50) +X[:, 1] .+= 0.8 .* latent +X[:, 3] .+= 0.7 .* latent + +# ── Method 1: Precision matrix ──────────────────────────────────────────────── +# partialCorrelationMat returns the full p×p partial correlation matrix. +# Note: accessed via module prefix since it is an internal function. +Pfull = InferelatorJL.partialCorrelationMat(X; epsilon = 1e-6, first_vs_all = false) +println("Full partial correlation matrix (6×6):") +display(round.(Pfull, digits = 3)) + +# first_vs_all = true: only partial correlations of column 1 vs all others. +# This is how it is used inside rankEdges! (target gene vs. TF predictors). +P1 = InferelatorJL.partialCorrelationMat(X; epsilon = 1e-6, first_vs_all = true) +println("\nPartial correlations of gene 1 vs. 
all others (1×6):") +println(round.(P1, digits = 3)) + +# ── Method 2: Regression residuals ─────────────────────────────────────────── +# partialCorrReg computes the same quantity via OLS residuals. +# Slower than the precision-matrix method; use for small matrices. +Preg = InferelatorJL.partialCorrReg(X; first_vs_all = false) +println("\nPartial correlation (regression method, 6×6):") +display(round.(Preg, digits = 3)) + +println("\nDifference between methods (should be ~0):") +println("Max abs diff: ", round(maximum(abs.(Pfull[2:end, 2:end] .- Preg[2:end, 2:end])), digits = 6)) + + +# ============================================================================= +# 3. NetworkIO +# ============================================================================= + +println("\n===== 3. NetworkIO =====\n") + +# ── writeNetworkTable! ──────────────────────────────────────────────────────── +# Populate a minimal BuildGrn with synthetic edge data and write TSV outputs. +buildGrn = BuildGrn() +buildGrn.regs = ["TF_A", "TF_B", "TF_A", "TF_C"] +buildGrn.targs = ["Gata3", "Foxp3", "Tbx21", "Rorc"] +buildGrn.signedQuantile = [0.92, -0.75, 0.60, 0.85] +buildGrn.rankings = [0.88, 0.72, 0.55, 0.80] +buildGrn.partialCorrelation = [0.45, -0.38, 0.30, 0.41] +buildGrn.inPrior = ["Yes", "Yes", "No", "Yes"] +buildGrn.networkMat = hcat(buildGrn.regs, buildGrn.targs, + buildGrn.signedQuantile, buildGrn.rankings, + buildGrn.partialCorrelation, buildGrn.inPrior) +buildGrn.networkMatSubset = buildGrn.networkMat[1:3, :] # top 3 edges as subset + +netDir = joinpath(tmpDir, "network_example") +mkpath(netDir) +writeNetworkTable!(buildGrn; outputDir = netDir) +println("edges.tsv written to: $netDir") +println(read(joinpath(netDir, "edges.tsv"), String)) + + +# ── saveData ────────────────────────────────────────────────────────────────── +# Save all four core structs to a .jld2 file for checkpointing. 
+# Reload in a fresh session with: +# using InferelatorJL, JLD2 +# @load joinpath(netDir, "checkpoint.jld2") expressionData tfaData grnData buildGrn +# +# (Skipped here since data/tfaData/grnData require real inputs) + + +# ============================================================================= +# 4. aggregateNetworks +# ============================================================================= + +println("\n===== 4. aggregateNetworks =====\n") + +# Write two synthetic edge files (TFA and mRNA modes) then combine them. +tfaEdges = DataFrame( + TF = ["TF_A", "TF_B", "TF_C", "TF_A"], + Gene = ["Gata3", "Foxp3", "Tbx21", "Rorc"], + signedQuantile = [0.92, -0.75, 0.60, 0.85], + Stability = [0.88, 0.72, 0.55, 0.80], + Correlation = [0.45, -0.38, 0.30, 0.41], + inPrior = ["Yes", "Yes", "No", "Yes"] +) + +mrnaEdges = DataFrame( + TF = ["TF_A", "TF_B", "TF_C", "TF_D"], + Gene = ["Gata3", "Foxp3", "Tbx21", "Rorc"], + signedQuantile = [0.80, -0.65, 0.70, 0.55], + Stability = [0.75, 0.60, 0.68, 0.50], + Correlation = [0.40, -0.32, 0.35, 0.28], + inPrior = ["Yes", "Yes", "No", "No"] +) + +tfaFile = joinpath(tmpDir, "TFA_edges.tsv") +mrnaFile = joinpath(tmpDir, "mRNA_edges.tsv") +CSV.write(tfaFile, tfaEdges; delim = '\t') +CSV.write(mrnaFile, mrnaEdges; delim = '\t') + +combDir = joinpath(tmpDir, "Combined") + +# Combine using max stability per (TF, Gene) pair +combinedMax = aggregateNetworks( + [tfaFile, mrnaFile]; + method = :max, + meanEdgesPerGene = 3, + useMeanEdgesPerGene = true, + outputDir = combDir +) +println("Combined network (:max strategy), $(nrow(combinedMax)) edges:") +println(combinedMax) + +# Combine using mean stability +combinedMean = aggregateNetworks( + [tfaFile, mrnaFile]; + method = :mean, + meanEdgesPerGene = 3, + useMeanEdgesPerGene = true, + outputDir = combDir +) +println("\nCombined network (:mean strategy), $(nrow(combinedMean)) edges:") +println(combinedMean) diff --git a/experimental/.DS_Store b/experimental/.DS_Store new file mode 100644 
index 0000000..99c7495 Binary files /dev/null and b/experimental/.DS_Store differ diff --git a/experimental/MTL/MultitaskGRN_equations.pptx b/experimental/MTL/MultitaskGRN_equations.pptx new file mode 100755 index 0000000..8f775a5 Binary files /dev/null and b/experimental/MTL/MultitaskGRN_equations.pptx differ diff --git a/experimental/MTL/MultitaskInferelator.jl b/experimental/MTL/MultitaskInferelator.jl new file mode 100755 index 0000000..aa335e4 --- /dev/null +++ b/experimental/MTL/MultitaskInferelator.jl @@ -0,0 +1,28 @@ +module MultitaskInferelator + + # ── Unchanged modules from original Inferelator ────────────────────────── + include("core/geneExpression.jl") # Data module [UNCHANGED] + include("core/mergeDegenerateTFs.jl") # MergeDegenerate [UNCHANGED] + include("core/networkIO.jl") # NetworkIO [UNCHANGED] + include("utils/dataUtils.jl") # DataUtils [UNCHANGED] + include("utils/utilsGRN.jl") # firstNByGroup [UNCHANGED] + + # ── Modified modules ────────────────────────────────────────────────────── + include("core/priorTFA.jl") # PriorTFA [MODIFIED - per-task TFA] + include("core/GRN.jl") # GRN structs [MODIFIED - new structs] + + # ── New multitask modules ───────────────────────────────────────────────── + include("multitask/multitaskData.jl") # MT data structs [NEW] + include("multitask/taskSimilarity.jl") # Study graph [NEW] + include("multitask/admm.jl") # ADMM solver [NEW] + include("multitask/prepareMultitask.jl")# MT prepare fns [NEW] + include("multitask/buildMultitask.jl") # MT build/rank fns [NEW] + include("multitask/combineMultitask.jl")# MT combine fns [NEW] + + # ── Evaluation (unchanged) ──────────────────────────────────────────────── + include("evaluation/Metric.jl") + include("evaluation/calcPRinfTRNs.jl") + include("evaluation/plotSingleUtils.jl") + include("evaluation/plotBatchMetrics.jl") + +end diff --git a/experimental/MTL/admm 2.jl b/experimental/MTL/admm 2.jl new file mode 100755 index 0000000..a5cb45f --- /dev/null +++ 
"""
    softThreshold(x, threshold)

Scalar soft-thresholding operator: shrink `x` toward zero by `threshold`,
clamping at zero.  Used in the ADMM Z-update (fusion proximal operator).
"""
@inline function softThreshold(x::Float64, threshold::Float64)
    return sign(x) * max(abs(x) - threshold, 0.0)
end


"""
    admm_fused_lasso(predictorMats, responseMats, penaltyFactors,
                     taskSimilarity, lambda_s, lambda_f;
                     rho, maxIter, tol, elasticNetAlpha)

Solve the graph-fused multitask LASSO for a single target gene via ADMM.

# Arguments
- `predictorMats`  : one TFs × samples matrix per task
- `responseMats`   : one response vector (samples,) per task
- `penaltyFactors` : one per-TF penalty vector per task
- `taskSimilarity` : tasks × tasks similarity matrix (fusion edges where > 0)
- `lambda_s`       : sparsity (LASSO) penalty
- `lambda_f`       : fusion penalty strength
- `rho`            : ADMM augmented-Lagrangian parameter
- `maxIter`, `tol` : iteration cap and convergence tolerance on ||W - W_prev||
- `elasticNetAlpha`: GLMNet `alpha` mixing — 1.0 = pure LASSO,
                     0.5 = equal L1+L2 (recommended when demerged TFs are
                     present), 0.0 = pure ridge

# Returns
- `W` : TFs × tasks coefficient matrix for this gene
"""
function admm_fused_lasso(
        predictorMats::Vector{Matrix{Float64}},
        responseMats::Vector{Vector{Float64}},
        penaltyFactors::Vector{Vector{Float64}},
        taskSimilarity::Matrix{Float64},
        lambda_s::Float64,
        lambda_f::Float64;
        rho::Float64 = 1.0,
        maxIter::Int = 100,
        tol::Float64 = 1e-4,
        elasticNetAlpha::Float64 = 1.0
    )

    nTasks = length(predictorMats)
    nTFs   = size(predictorMats[1], 1)

    # Fusion edges: upper-triangle task pairs with nonzero similarity.
    edges  = [(d, dp) for d in 1:nTasks for dp in (d+1):nTasks
              if taskSimilarity[d, dp] > 0]
    nEdges = length(edges)

    W = zeros(Float64, nTFs, nTasks)   # primal variable
    Z = zeros(Float64, nTFs, nEdges)   # fusion slack variables
    U = zeros(Float64, nTFs, nEdges)   # scaled dual variables

    sqrtRho = sqrt(rho)   # loop-invariant; was recomputed per edge per iteration

    for iter in 1:maxIter
        W_prev = copy(W)

        # W-update: per-task elastic net with the ADMM proximity terms
        # absorbed into an augmented design matrix / response.
        for d in 1:nTasks
            dEdges = [(e, d1 == d ? 1.0 : -1.0)
                      for (e, (d1, d2)) in enumerate(edges)
                      if d1 == d || d2 == d]
            nSamps = size(predictorMats[d], 2)
            A = transpose(predictorMats[d])   # samples × TFs
            x = responseMats[d]

            if isempty(dEdges)
                lsoln = glmnet(A, x,
                               penalty_factor = penaltyFactors[d],
                               lambda = [lambda_s],
                               alpha = elasticNetAlpha)
                W[:, d] = vec(lsoln.betas)
            else
                # Build the augmented system with a single concatenation per
                # side instead of repeated `vcat` in the loop (which copies
                # the whole accumulated matrix each time — O(k²)).
                partsA = AbstractMatrix{Float64}[A]
                partsX = AbstractVector{Float64}[x]
                for (e, sgn) in dEdges
                    target = sgn > 0 ? Z[:, e] - U[:, e] : -Z[:, e] - U[:, e]
                    push!(partsA, sqrtRho * I(nTFs))
                    push!(partsX, sqrtRho * target)
                end
                augA = reduce(vcat, partsA)
                augX = reduce(vcat, partsX)
                lsoln = glmnet(augA, augX,
                               penalty_factor = penaltyFactors[d],
                               lambda = [lambda_s / (nSamps + nTFs * length(dEdges))],
                               alpha = elasticNetAlpha)
                W[:, d] = vec(lsoln.betas)
            end
        end

        # Z-update: soft threshold on pairwise differences (fusion prox).
        for (e, (d, dp)) in enumerate(edges)
            diff = W[:, d] - W[:, dp] + U[:, e]
            threshold = lambda_f * taskSimilarity[d, dp] / rho
            Z[:, e] = softThreshold.(diff, threshold)
        end

        # U-update: standard scaled dual ascent.
        for (e, (d, dp)) in enumerate(edges)
            U[:, e] += W[:, d] - W[:, dp] - Z[:, e]
        end

        norm(W - W_prev) < tol && break
    end

    return W
end


"""
    admmWarmStart(predictorMats, responseMats, penaltyFactors,
                  taskSimilarity, lambdaRange_s, lambda_f;
                  elasticNetAlpha, kwargs...)

Run the fused-LASSO ADMM solver across a range of `lambda_s` values
(warm-start sweep).  Returns a TFs × tasks × lambdas array of coefficients.
"""
function admmWarmStart(
        predictorMats::Vector{Matrix{Float64}},
        responseMats::Vector{Vector{Float64}},
        penaltyFactors::Vector{Vector{Float64}},
        taskSimilarity::Matrix{Float64},
        lambdaRange_s::Vector{Float64},
        lambda_f::Float64;
        elasticNetAlpha::Float64 = 1.0,
        kwargs...
    )
    nTFs    = size(predictorMats[1], 1)
    nTasks  = length(predictorMats)
    nLambda = length(lambdaRange_s)
    betasByLambda = Array{Float64, 3}(undef, nTFs, nTasks, nLambda)

    for (li, ls) in enumerate(lambdaRange_s)
        betasByLambda[:, :, li] = admm_fused_lasso(
            predictorMats, responseMats, penaltyFactors,
            taskSimilarity, ls, lambda_f;
            elasticNetAlpha = elasticNetAlpha, kwargs...
        )
    end
    return betasByLambda
end
"""
    selectLambdas(mtGrnData, mtData, res; lambdaOpt, ...)

Top-level dispatcher for joint (lambda_s, lambda_f) selection for target
gene `res`.

lambdaOpt:
  :fixed_ratio — lambda_f = fusionRatio * lambda_s (default, zero added cost)
  :ebic        — EBIC grid search over the 2D (lambda_s, lambda_f) plane
  :bstars_2d   — 2D bStARS instability surface + EBIC tiebreaker

Returns a `(lambda_s, lambda_f)` tuple.
"""
function selectLambdas(
        mtGrnData::MultitaskGrnData,
        mtData::MultitaskExpressionData,
        res::Int;
        lambdaOpt::Symbol = :fixed_ratio,
        fusionRatio::Float64 = 0.1,
        ebicGamma::Float64 = 1.0,
        gridSize::Int = 10,
        refinementSize::Int = 10,
        targetInstability::Float64 = 0.05,
        elasticNetAlpha::Float64 = 1.0,
        totSS::Int = 20,
        zTarget::Bool = false
    )

    if lambdaOpt == :fixed_ratio
        # Pass targetInstability through so helper and dispatcher agree on
        # the instability target (it was previously hard-coded to 0.05 there).
        return _selectLambdas_fixedRatio(mtGrnData, res;
                                         fusionRatio, targetInstability)
    elseif lambdaOpt == :ebic
        return _selectLambdas_ebic(mtGrnData, mtData, res;
                                   ebicGamma, gridSize, elasticNetAlpha,
                                   totSS, zTarget)
    elseif lambdaOpt == :bstars_2d
        return _selectLambdas_bstars2d(mtGrnData, mtData, res;
                                       targetInstability, ebicGamma,
                                       coarseSize = gridSize,
                                       fineSize = refinementSize,
                                       elasticNetAlpha, totSS, zTarget)
    else
        error("Unknown lambdaOpt: $lambdaOpt. Choose :fixed_ratio, :ebic, or :bstars_2d")
    end
end


# ── Option 1: Fixed ratio ──────────────────────────────────────────────────
"""
    _selectLambdas_fixedRatio(mtGrnData, res; fusionRatio, targetInstability)

Pick lambda_s from the reference task's precomputed per-gene instability
curve — the lambda whose instability is closest to `targetInstability`,
taking the last (largest-index) lambda on ties — then set
lambda_f = fusionRatio * lambda_s.
"""
function _selectLambdas_fixedRatio(mtGrnData::MultitaskGrnData, res::Int;
                                   fusionRatio::Float64 = 0.1,
                                   targetInstability::Float64 = 0.05)
    # NOTE(review): only task 1's instability curve is consulted here —
    # confirm that is the intended "reference task" convention.
    refGrn = mtGrnData.taskGrnData[1]
    lambdaRangeGene = refGrn.lambdaRangeGene[res]
    currInstabs = refGrn.geneInstabilities[res, :]
    devs = abs.(currInstabs .- targetInstability)
    minInd = findlast(==(minimum(devs)), devs)
    lambda_s = lambdaRangeGene[minInd]
    return lambda_s, fusionRatio * lambda_s
end


# ── Option 2: EBIC grid search ─────────────────────────────────────────────
"""
    _selectLambdas_ebic(mtGrnData, mtData, res;
                        ebicGamma, gridSize, elasticNetAlpha, totSS, zTarget)

Grid search over (lambda_s, lambda_f) pairs: each candidate is solved with
the fused ADMM and scored by EBIC; the lowest-EBIC pair wins.
"""
function _selectLambdas_ebic(
        mtGrnData::MultitaskGrnData,
        mtData::MultitaskExpressionData,
        res::Int;
        ebicGamma::Float64 = 1.0,
        gridSize::Int = 10,
        elasticNetAlpha::Float64 = 1.0,
        totSS::Int = 20,
        zTarget::Bool = false
    )
    nTasks = length(mtData.tasks)
    responsePredInds = [findall(x -> x != Inf, mtGrnData.taskGrnData[d].penaltyMat[res, :])
                        for d in 1:nTasks]

    # TFs × samples matrices for the ADMM solver …
    # FIX: the original passed lazy samples × TFs `Transpose` objects to
    # admm_fused_lasso, which both violates its Vector{Matrix{Float64}}
    # signature (MethodError) and flips the TFs × samples layout it expects.
    taskPredTFbySamp = [mtGrnData.taskGrnData[d].predictorMat[responsePredInds[d], :]
                        for d in 1:nTasks]
    # … and materialized samples × TFs transposes for EBIC scoring.
    taskPredMats  = [Matrix(transpose(M)) for M in taskPredTFbySamp]
    taskRespVecs  = [vec(mtGrnData.taskGrnData[d].responseMat[res, :]) for d in 1:nTasks]
    taskPenalties = [mtGrnData.taskGrnData[d].penaltyMat[res, responsePredInds[d]]
                     for d in 1:nTasks]

    lambda_s_grid = exp.(range(log(0.01), log(10.0), length = gridSize))
    alpha_grid = range(0.05, 1.0, length = gridSize)

    # Traverse from high to low (lambda_s + lambda_f) for warm starting.
    grid_pairs = [(ls, a * ls)
                  for ls in reverse(lambda_s_grid)
                  for a in reverse(alpha_grid)
                  if 0.5 < ls / (a * ls) < 2.0]
    # Guard: if the ratio filter removed every candidate, fall back to the
    # full grid instead of erroring on grid_pairs[1] below.
    if isempty(grid_pairs)
        grid_pairs = [(ls, a * ls) for ls in reverse(lambda_s_grid)
                                   for a in reverse(alpha_grid)]
    end

    bestEBIC = Inf
    best_ls, best_lf = grid_pairs[1]

    for (ls, lf) in grid_pairs
        W = admm_fused_lasso(taskPredTFbySamp, taskRespVecs, taskPenalties,
                             mtGrnData.taskGraph, ls, lf;
                             elasticNetAlpha = elasticNetAlpha)
        ebic = _ebic(W, taskPredMats, taskRespVecs, nTasks, ebicGamma)
        if ebic < bestEBIC
            bestEBIC = ebic; best_ls = ls; best_lf = lf
        end
    end
    return best_ls, best_lf
end


# ── Option 3: 2D bStARS ────────────────────────────────────────────────────
"""
    _selectLambdas_bstars2d(mtGrnData, mtData, res; ...)

Two-stage search: a coarse instability grid locates the
target-instability contour, a fine grid refines that region, and EBIC
breaks ties among on-contour candidates.
"""
function _selectLambdas_bstars2d(
        mtGrnData::MultitaskGrnData,
        mtData::MultitaskExpressionData,
        res::Int;
        targetInstability::Float64 = 0.05,
        ebicGamma::Float64 = 1.0,
        coarseSize::Int = 5,
        fineSize::Int = 10,
        elasticNetAlpha::Float64 = 1.0,
        totSS::Int = 20,
        zTarget::Bool = false
    )
    nTasks = length(mtData.tasks)
    responsePredInds = [findall(x -> x != Inf, mtGrnData.taskGrnData[d].penaltyMat[res, :])
                        for d in 1:nTasks]

    # Stage 1: coarse instability surface.
    ls_c = exp.(range(log(0.01), log(5.0), length = coarseSize))
    lf_c = exp.(range(log(0.001), log(2.0), length = coarseSize))
    D_c = _instabGrid(mtGrnData, mtData, res, ls_c, lf_c,
                      responsePredInds, nTasks, totSS, zTarget, elasticNetAlpha)

    # Cells near the target-instability contour (±50% band).
    near = findall(x -> abs(x - targetInstability) < targetInstability * 0.5, D_c)
    if isempty(near)
        _, idx = findmin(abs.(D_c .- targetInstability))
        near = [idx]
    end
    rs = first.(Tuple.(near)); cs = last.(Tuple.(near))
    ls_range = ls_c[max(1, minimum(rs) - 1):min(coarseSize, maximum(rs) + 1)]
    lf_range = lf_c[max(1, minimum(cs) - 1):min(coarseSize, maximum(cs) + 1)]

    # Stage 2: fine grid restricted to that region.
    ls_f = collect(range(minimum(ls_range), maximum(ls_range), length = fineSize))
    lf_f = collect(range(minimum(lf_range), maximum(lf_range), length = fineSize))
    D_f = _instabGrid(mtGrnData, mtData, res, ls_f, lf_f,
                      responsePredInds, nTasks, totSS, zTarget, elasticNetAlpha)

    # Points on the contour (±30% band), or the single closest cell.
    onContour = findall(x -> abs(x - targetInstability) < targetInstability * 0.3, D_f)
    if isempty(onContour)
        _, idx = findmin(abs.(D_f .- targetInstability))
        onContour = [idx]
    end

    # EBIC tiebreaker among on-contour candidates.
    # Same orientation fix as in _selectLambdas_ebic: TFs × samples for the
    # solver, materialized samples × TFs for EBIC.
    taskPredTFbySamp = [mtGrnData.taskGrnData[d].predictorMat[responsePredInds[d], :]
                        for d in 1:nTasks]
    taskPredMats  = [Matrix(transpose(M)) for M in taskPredTFbySamp]
    taskRespVecs  = [vec(mtGrnData.taskGrnData[d].responseMat[res, :]) for d in 1:nTasks]
    taskPenalties = [mtGrnData.taskGrnData[d].penaltyMat[res, responsePredInds[d]]
                     for d in 1:nTasks]

    bestEBIC = Inf; best_ls = ls_f[1]; best_lf = lf_f[1]
    for idx in onContour
        r, c = Tuple(idx)
        ls = ls_f[r]; lf = lf_f[c]
        W = admm_fused_lasso(taskPredTFbySamp, taskRespVecs, taskPenalties,
                             mtGrnData.taskGraph, ls, lf;
                             elasticNetAlpha = elasticNetAlpha)
        ebic = _ebic(W, taskPredMats, taskRespVecs, nTasks, ebicGamma)
        if ebic < bestEBIC
            bestEBIC = ebic; best_ls = ls; best_lf = lf
        end
    end
    return best_ls, best_lf
end


# ── Shared helpers ─────────────────────────────────────────────────────────
"""
    _ebic(W, taskPredMats, taskRespVecs, nTasks, gamma)

Mean extended BIC across tasks for the coefficient matrix `W`
(TFs × tasks); `taskPredMats[d]` is samples × TFs.

EBIC_d = n·log(RSS/n) + k·log(n) + 2γ·log C(p, k), averaged over tasks.
"""
function _ebic(W::Matrix{Float64},
               taskPredMats::Vector{<:AbstractMatrix{Float64}},
               taskRespVecs::Vector{<:AbstractVector{Float64}},
               nTasks::Int,
               gamma::Float64)
    ebic = 0.0
    for d in 1:nTasks
        n_d = length(taskRespVecs[d])
        p_d = size(taskPredMats[d], 2)
        w_d = W[:, d]
        k_d = count(b -> abs(b) > 1e-10, w_d)
        yhat = taskPredMats[d] * w_d
        rss = max(sum(abs2, taskRespVecs[d] .- yhat), 1e-12)  # floor avoids log(0)
        # FIX: log C(p, k) via plain log-sums — the original used `lgamma`,
        # which is not in Base on Julia ≥ 1.0 (it lives in SpecialFunctions)
        # and would raise UndefVarError at runtime.
        logC = (k_d > 0 && p_d > k_d) ?
               sum(log, (p_d - k_d + 1):p_d) - sum(log, 1:k_d) : 0.0
        ebic += n_d * log(rss / n_d) + k_d * log(n_d) + 2 * gamma * logC
    end
    return ebic / nTasks
end


"""
    _instabGrid(mtGrnData, mtData, res, ls_grid, lf_grid,
                responsePredInds, nTasks, totSS, zTarget, elasticNetAlpha)

Mean edge instability 2θ(1−θ) over subsamples for every
(lambda_s, lambda_f) grid cell, where θ is the per-task, per-TF selection
frequency across subsamples.
"""
function _instabGrid(
        mtGrnData, mtData, res,
        ls_grid, lf_grid,
        responsePredInds, nTasks, totSS, zTarget, elasticNetAlpha
    )
    nLS = length(ls_grid)
    nLF = length(lf_grid)
    instabs = zeros(Float64, nLS, nLF)
    # NOTE(review): subsample count comes from task 1, and ssVals is sized
    # by task 1's predictor count — assumes all tasks share those sizes;
    # confirm upstream.
    subsamps = mtGrnData.taskGrnData[1].subsamps

    for (li, ls) in enumerate(ls_grid), (fi, lf) in enumerate(lf_grid)
        ssVals = zeros(Float64, nTasks, length(responsePredInds[1]))
        nss = min(totSS, size(subsamps, 1))

        for ss in 1:nss
            preds = Matrix{Float64}[]; resps = Vector{Float64}[]; pens = Vector{Float64}[]
            for d in 1:nTasks
                sub = mtGrnData.taskGrnData[d].subsamps[ss, :]
                pidx = responsePredInds[d]
                dt = fit(ZScoreTransform,
                         mtGrnData.taskGrnData[d].predictorMat[pidx, sub], dims = 2)
                cp = transpose(StatsBase.transform(dt,
                         mtGrnData.taskGrnData[d].predictorMat[pidx, sub]))
                cr = zTarget ?
                     StatsBase.transform(fit(ZScoreTransform,
                         mtGrnData.taskGrnData[d].responseMat[res, sub], dims = 1),
                         mtGrnData.taskGrnData[d].responseMat[res, sub]) :
                     mtGrnData.taskGrnData[d].responseMat[res, sub]
                push!(preds, transpose(cp)); push!(resps, vec(cr))
                push!(pens, mtGrnData.taskGrnData[d].penaltyMat[res, pidx])
            end
            W = admm_fused_lasso(preds, resps, pens, mtGrnData.taskGraph, ls, lf;
                                 elasticNetAlpha = elasticNetAlpha)
            # FIX: accumulate each task's OWN selection indicators.  The
            # original added `sum(abs.(sign.(W)), dims=2)` — indicators summed
            # over ALL tasks — into every task row, which pushes θ above 1 and
            # makes the instability 2θ(1−θ) negative.
            for d in 1:nTasks
                ssVals[d, :] .+= abs.(sign.(W[:, d]))
            end
        end

        theta = ssVals ./ nss                     # selection frequency in [0, 1]
        instabs[li, fi] = mean(2 .* theta .* (1 .- theta))
    end
    return instabs
end
"""
admm.jl [NEW FILE]

ADMM solver for the graph-fused multitask LASSO.

Objective (per target gene i):

  min_W  (1/2n) Σ_d ||X_i^(d) - A^(d)T w_i^(d)||²
       + λ_s Σ_{k,d} |Φ_{k,d} w_{k,d}|                    ← prior-weighted LASSO
       + λ_f Σ_{(d,d')∈E} sim(d,d') ||w^(d) - w^(d')||₁   ← fusion

where W is a TFs × tasks matrix for gene i.

ADMM reformulation: introduce Z^(d,d') = w^(d) - w^(d') for each edge.
W-update : per-task weighted LASSO (GLMNet — unchanged from original)
Z-update : soft thresholding on pairwise differences (fusion proximal op)
U-update : scaled dual variable update (standard ADMM)
"""


"""
    softThreshold(x, threshold)

Scalar soft-thresholding operator used in the Z-update.
"""
@inline function softThreshold(x::Float64, threshold::Float64)
    return sign(x) * max(abs(x) - threshold, 0.0)
end


"""
    admm_fused_lasso(
        predictorMats, responseMats, penaltyFactors,
        taskSimilarity, lambda_s, lambda_f;
        rho=1.0, maxIter=100, tol=1e-4, alpha=1.0
    )

Solve the graph-fused multitask LASSO for a single target gene using ADMM.

# Arguments
- `predictorMats` : Vector of predictor matrices, one per task (TFs × samples)
- `responseMats`  : Vector of response vectors, one per task (samples,)
- `penaltyFactors`: Vector of penalty factor vectors, one per task (TFs,)
- `taskSimilarity`: tasks × tasks similarity matrix
- `lambda_s`      : LASSO sparsity penalty
- `lambda_f`      : fusion penalty strength
- `rho`           : ADMM augmented Lagrangian parameter
- `maxIter`       : maximum ADMM iterations
- `tol`           : convergence tolerance
- `alpha`         : elastic net mixing (1.0 = pure LASSO)

# Returns
- `W` : TFs × tasks matrix of coefficients for this gene
"""
function admm_fused_lasso(
        predictorMats::Vector{Matrix{Float64}},
        responseMats::Vector{Vector{Float64}},
        penaltyFactors::Vector{Vector{Float64}},
        taskSimilarity::Matrix{Float64},
        lambda_s::Float64,
        lambda_f::Float64;
        rho::Float64 = 1.0,
        maxIter::Int = 100,
        tol::Float64 = 1e-4,
        alpha::Float64 = 1.0
    )

    nTasks = length(predictorMats)
    nTFs   = size(predictorMats[1], 1)

    # Edge list from the similarity matrix (upper triangle, nonzero entries).
    edges = [(d, dp) for d in 1:nTasks for dp in (d+1):nTasks
             if taskSimilarity[d, dp] > 0]
    nEdges = length(edges)

    W = zeros(Float64, nTFs, nTasks)   # TFs × tasks — the main variable
    Z = zeros(Float64, nTFs, nEdges)   # TFs × edges — fusion slack variables
    U = zeros(Float64, nTFs, nEdges)   # TFs × edges — dual variables

    sqrtRho = sqrt(rho)   # loop-invariant; was recomputed per edge per iteration

    for iter in 1:maxIter
        W_prev = copy(W)

        # ── W-update: per-task weighted LASSO with augmented penalty ───────
        # For each task d, solve:
        #   min (1/2n)||X^(d) - A^(d)T w^(d)||²
        #       + λ_s |Φ^(d) ⊙ w^(d)|₁
        #       + (ρ/2) Σ_{e∋d} ||w^(d) - Z_e + U_e||²
        # The quadratic ADMM augmentation is absorbed into an augmented
        # response/design pair passed to GLMNet.
        for d in 1:nTasks
            # FIX: the original comprehension iterated a *scalar*
            # (`for sign in (d1 == d ? 1.0 : -1.0)`) — legal only because a
            # number is its own one-element iterator — and shadowed
            # `Base.sign`.  Build the (edge, sign) pairs directly instead.
            dEdges = [(e, d1 == d ? 1.0 : -1.0)
                      for (e, (d1, d2)) in enumerate(edges)
                      if d1 == d || d2 == d]

            nSamps = size(predictorMats[d], 2)
            A = transpose(predictorMats[d])   # samples × TFs
            x = responseMats[d]

            if isempty(dEdges)
                # No fusion neighbors — standard GLMNet call.
                lsoln = glmnet(A, x,
                               penalty_factor = penaltyFactors[d],
                               lambda = [lambda_s],
                               alpha = alpha)
                W[:, d] = vec(lsoln.betas)
            else
                # Augment response and predictors with ADMM proximity terms,
                # concatenating once per side (repeated vcat is O(k²) copies).
                partsA = AbstractMatrix{Float64}[A]
                partsX = AbstractVector{Float64}[x]
                for (e, sgn) in dEdges
                    target = sgn > 0 ? Z[:, e] - U[:, e] : -Z[:, e] - U[:, e]
                    push!(partsA, sqrtRho * I(nTFs))
                    push!(partsX, sqrtRho * target)
                end
                augA = reduce(vcat, partsA)
                augX = reduce(vcat, partsX)
                lsoln = glmnet(augA, augX,
                               penalty_factor = penaltyFactors[d],  # penalty unchanged
                               lambda = [lambda_s / (nSamps + nTFs * length(dEdges))],
                               alpha = alpha)
                W[:, d] = vec(lsoln.betas)
            end
        end

        # ── Z-update: fusion proximal operator (soft threshold) ────────────
        # Z_e = soft_threshold(w^(d) - w^(d') + U_e, λ_f * sim(d,d') / ρ)
        for (e, (d, dp)) in enumerate(edges)
            diff = W[:, d] - W[:, dp] + U[:, e]
            threshold = lambda_f * taskSimilarity[d, dp] / rho
            Z[:, e] = softThreshold.(diff, threshold)
        end

        # ── U-update: dual variable ────────────────────────────────────────
        for (e, (d, dp)) in enumerate(edges)
            U[:, e] += W[:, d] - W[:, dp] - Z[:, e]
        end

        # ── Convergence check ──────────────────────────────────────────────
        if norm(W - W_prev) < tol
            break
        end
    end

    return W   # TFs × tasks
end


"""
    admmWarmStart(
        predictorMats, responseMats, penaltyFactors,
        taskSimilarity, lambdaRange_s, lambda_f;
        kwargs...
    )

Run ADMM across a range of `lambda_s` values.  Analogous to
bstarsWarmStart in the original codebase but operates on the fused
multitask objective.

Returns a 3D array of coefficients, TFs × tasks × lambdas.
(Doc fix: the original docstring claimed a 4D array of binary edge
indicators, which never matched the implementation.)
"""
function admmWarmStart(
        predictorMats::Vector{Matrix{Float64}},
        responseMats::Vector{Vector{Float64}},
        penaltyFactors::Vector{Vector{Float64}},
        taskSimilarity::Matrix{Float64},
        lambdaRange_s::Vector{Float64},
        lambda_f::Float64;
        kwargs...
    )
    nTFs    = size(predictorMats[1], 1)
    nTasks  = length(predictorMats)
    nLambda = length(lambdaRange_s)

    betasByLambda = Array{Float64, 3}(undef, nTFs, nTasks, nLambda)

    for (li, ls) in enumerate(lambdaRange_s)
        betasByLambda[:, :, li] = admm_fused_lasso(
            predictorMats, responseMats, penaltyFactors,
            taskSimilarity, ls, lambda_f; kwargs...
        )
    end

    return betasByLambda
end
"""
    chooseLambdaMT!(mtGrnData, mtBuildGrn; instabilityLevel, targetInstability)

Run chooseLambda! independently for every task.  Lambda selection stays
per-task because the optimal regularization strength may differ across
cell types.  [UNCHANGED per-task logic]
"""
function chooseLambdaMT!(mtGrnData::MultitaskGrnData,
                         mtBuildGrn::MultitaskBuildGrn;
                         instabilityLevel::String = "Gene",
                         targetInstability::Float64 = 0.05)

    for d in eachindex(mtBuildGrn.tasks)
        chooseLambda!(mtGrnData.taskGrnData[d], mtBuildGrn.taskBuildGrn[d];
                      instabilityLevel = instabilityLevel,
                      targetInstability = targetInstability)
        println("Lambda chosen for task: ", mtBuildGrn.tasks[d])
    end
end


"""
    rankEdgesMT!(mtData, tfaDataVec, mtGrnData, mtBuildGrn;
                 mergedTFsData, useMeanEdgesPerGeneMode, meanEdgesPerGene,
                 correlationWeight, outputDir)

Run rankEdges! independently for every task (each task writes into its own
subdirectory of `outputDir` when one is given), then build the cross-task
consensus network.  [UNCHANGED per-task logic, NEW consensus step]
"""
function rankEdgesMT!(mtData::MultitaskExpressionData,
                      tfaDataVec::Vector{PriorTFAData},
                      mtGrnData::MultitaskGrnData,
                      mtBuildGrn::MultitaskBuildGrn;
                      mergedTFsData::Union{mergedTFsResult, Nothing} = nothing,
                      useMeanEdgesPerGeneMode::Bool = true,
                      meanEdgesPerGene::Int = 20,
                      correlationWeight::Int = 1,
                      outputDir::Union{String, Nothing} = nothing)

    for d in eachindex(mtData.tasks)
        # Per-task output folder; skipped entirely when outputDir is nothing.
        taskDir = outputDir === nothing ? nothing : joinpath(outputDir, mtData.tasks[d])
        taskDir === nothing || mkpath(taskDir)

        rankEdges!(mtData.taskData[d], tfaDataVec[d],
                   mtGrnData.taskGrnData[d], mtBuildGrn.taskBuildGrn[d];
                   mergedTFsData = mergedTFsData,
                   useMeanEdgesPerGeneMode = useMeanEdgesPerGeneMode,
                   meanEdgesPerGene = meanEdgesPerGene,
                   correlationWeight = correlationWeight,
                   outputDir = taskDir)

        println("Edges ranked for task: ", mtData.tasks[d])
    end

    # Build the consensus network across tasks.
    buildConsensus!(mtBuildGrn, mtData.tasks)
end


"""
    buildConsensus!(mtBuildGrn, tasks)

Average the per-task stability matrices into `mtBuildGrn.consensusNetwork`
(genes × TFs).  For each entry the mean is taken only over the tasks whose
stability is nonzero, so edges absent from a task do not dilute the
consensus.  Analogous to the combineGRNs "mean" option but applied to the
raw stability matrices before edge selection.
"""
function buildConsensus!(mtBuildGrn::MultitaskBuildGrn, tasks::Vector{String})
    nTasks = length(tasks)
    nTasks == 0 && return

    # Dimensions come from the first task; all tasks share them.
    refNet = mtBuildGrn.taskBuildGrn[1].networkStability
    nGenes, nTFs = size(refNet)

    consensusMat = zeros(Float64, nGenes, nTFs)
    countMat = zeros(Int, nGenes, nTFs)

    for d in 1:nTasks
        stab = mtBuildGrn.taskBuildGrn[d].networkStability
        consensusMat .+= stab
        countMat .+= Int.(stab .!= 0)
    end

    # Divide by the number of tasks with a nonzero edge (≥1 to avoid 0/0).
    mtBuildGrn.consensusNetwork = consensusMat ./ max.(countMat, 1)

    println("Consensus network built across ", nTasks, " tasks.")
end


"""
    writeNetworkTableMT!(mtBuildGrn; outputDir)

Write per-task edge tables (via the unchanged writeNetworkTable!) and the
consensus stability matrix to `outputDir`.  The consensus matrix is
written tab-delimited as `consensus_edges.tsv`.
"""
function writeNetworkTableMT!(mtBuildGrn::MultitaskBuildGrn;
                              outputDir::String)

    mkpath(outputDir)

    # per-task tables — unchanged writeNetworkTable! call
    for (d, task) in enumerate(mtBuildGrn.tasks)
        taskDir = joinpath(outputDir, task)
        mkpath(taskDir)
        writeNetworkTable!(mtBuildGrn.taskBuildGrn[d]; outputDir = taskDir)
    end

    # Consensus table.  BUG FIX: the original only *announced* this file —
    # nothing was ever written to disk.  The matrix is written row by row,
    # tab-delimited, using Base IO only.
    # NOTE(review): gene/TF labels are not visibly stored on
    # MultitaskBuildGrn — confirm and add header/row labels if available.
    consensusPath = joinpath(outputDir, "consensus_edges.tsv")
    if isdefined(mtBuildGrn, :consensusNetwork)
        open(consensusPath, "w") do io
            for row in eachrow(mtBuildGrn.consensusNetwork)
                println(io, join(row, '\t'))
            end
        end
    end

    println("Per-task networks written. Consensus network saved to: ",
            consensusPath)
end
consensus network across tasks + +New parameters vs original: +──────────────────────────────────────────────────────────────────── +taskLabelFile path to TSV mapping samples → task names +fusionLambda λ_f fusion penalty strength (default 0.1) +similaritySource :metadata, :expression, or :ontology +similarityFile path to metadata/ontology file (if needed) +""" +function runMultitaskInferelator(; + geneExprFile::String, + targFile::String, + regFile::String, + priorFile::String, + priorFilePenalties::Vector{String}, + taskLabelFile::String, # NEW — maps samples → tasks + tfaGeneFile::String = "", + outputDir::String, + tfaOptions::Vector{String} = ["", "TFmRNA"], + totSS::Int = 80, + bstarsTotSS::Int = 5, + subsampleFrac::Float64 = 0.68, + minLambda::Float64 = 0.01, + maxLambda::Float64 = 0.5, + totLambdasBstars::Int = 20, + totLambdas::Int = 40, + targetInstability::Float64 = 0.05, + meanEdgesPerGene::Int = 20, + correlationWeight::Int = 1, + minTargets::Int = 3, + edgeSS::Int = 0, + lambdaBias::Vector{Float64} = [0.5], + instabilityLevel::String = "Gene", + useMeanEdgesPerGeneMode::Bool = true, + combineOpt::String = "max", + zTarget::Bool = true, + fusionLambda::Float64 = 0.1, # NEW — fusion penalty (used by :fixed_ratio) + similaritySource::Symbol = :expression, # NEW — how to build task graph + similarityFile::Union{String, Nothing} = nothing, # NEW — metadata/ontology file + lambdaOpt::Symbol = :fixed_ratio, # NEW — :fixed_ratio | :ebic | :bstars_2d + fusionRatio::Float64 = 0.1, # NEW — for :fixed_ratio option + ebicGamma::Float64 = 1.0, # NEW — EBIC gamma (use 1.0 for p>>n) + gridSize::Int = 10, # NEW — lambda grid size for :ebic/:bstars_2d + refinementSize::Int = 10, # NEW — fine grid size for :bstars_2d + elasticNetAlpha::Float64 = 1.0 # NEW — L1/L2 mix (0.5 recommended with demerged TFs) +) + + # build output directory + subsamplePct = subsampleFrac * 100 + subsampleStr = isinteger(subsamplePct) ? 
string(Int(subsamplePct)) : replace(string(subsamplePct), "." => "p") + lambdaStr = join(replace.(string.(lambdaBias), "." => "p"), "_") + networkBaseName = "MT_" * lowercase(instabilityLevel) * "Lambda" * lambdaStr * "_" * + string(totSS) * "totSS_" * string(meanEdgesPerGene) * "tfsPerGene_subsamplePCT" * subsampleStr + dirOut = joinpath(outputDir, networkBaseName) + mkpath(dirOut) + + println("=== Multitask Inferelator Configuration ===") + println("Output Directory: ", dirOut) + println("Expression File: ", geneExprFile) + println("Task Label File: ", taskLabelFile) + println("Prior File: ", priorFile) + println("Fusion Lambda (λ_f): ", fusionLambda) + println("Similarity Source: ", similaritySource) + println("===========================================") + + # ── STEP 1: Load expression data [UNCHANGED] ────────────────────────────── + data = GeneExpressionData() + loadExpressionData!(data, geneExprFile) + loadAndFilterTargetGenes!(data, targFile; epsilon=0.01) + loadPotentialRegulators!(data, regFile) + processTFAGenes!(data, tfaGeneFile; outputDir=dirOut) + + # ── STEP 2: Split by task [NEW] ─────────────────────────────────────────── + taskLabels = vec(readdlm(taskLabelFile, String)) # one label per sample column + mtData = MultitaskExpressionData() + splitByTask!(mtData, data, taskLabels) + + # ── STEP 3: Build task similarity graph [NEW] ───────────────────────────── + if similaritySource == :metadata && similarityFile !== nothing + S = buildSimilarityFromMetadata(mtData.tasks, similarityFile) + elseif similaritySource == :ontology && similarityFile !== nothing + S = buildSimilarityFromOntology(mtData.tasks, similarityFile) + else + println("⚠ Using expression-based similarity — see taskSimilarity.jl for caveats.") + S = buildSimilarityFromExpression(mtData) + end + normalizeSimilarity!(S) + mtData.taskSimilarity = S + + # ── STEP 4: Merge degenerate TFs [UNCHANGED] ────────────────────────────── + mergedTFsData = mergedTFsResult() + 
mergeDegenerateTFs(mergedTFsData, priorFile; fileFormat=2) + + # ── STEP 5: Process prior file [UNCHANGED] ──────────────────────────────── + # Prior is shared across tasks — one prior file, processed once + tfaDataTemplate = PriorTFAData() + processPriorFile!(tfaDataTemplate, data, priorFile; mergedTFsData, minTargets=minTargets) + + # ── STEP 6: Per-task TFA estimation [NEW] ──────────────────────────────── + # Each task gets its own TFA estimate + tfaDataVec = [deepcopy(tfaDataTemplate) for _ in mtData.tasks] + calculateTFAPerTask!(tfaDataVec, mtData; + edgeSS=edgeSS, zTarget=zTarget, outputDir=dirOut) + + # ── STEP 7: Build GRN for each TFA option ──────────────────────────────── + for tfaOpt in tfaOptions + optName = tfaOpt == "" ? "TFA" : "TFmRNA" + instabilitiesDir = joinpath(dirOut, optName) + mkpath(instabilitiesDir) + + # initialize per-task GrnData + mtGrnData = MultitaskGrnData() + mtGrnData.fusionLambda = fusionLambda + mtGrnData.taskGraph = S + for _ in mtData.tasks + push!(mtGrnData.taskGrnData, GrnData()) + end + + # prepare matrices — UNCHANGED per-task logic + preparePredictorMatMT!(mtGrnData, mtData, tfaDataVec, tfaOpt) + preparePenaltyMatrixMT!(mtData, mtGrnData, priorFilePenalties, lambdaBias, tfaOpt) + constructSubsamplesMT!(mtData, mtGrnData; totSS=bstarsTotSS, subsampleFrac=subsampleFrac) + + # coarse warm start — run per task independently [UNCHANGED] + for d in 1:length(mtData.tasks) + bstarsWarmStart(mtData.taskData[d], tfaDataVec[d], mtGrnData.taskGrnData[d]; + minLambda=minLambda, maxLambda=maxLambda, + totLambdasBstars=totLambdasBstars, + targetInstability=targetInstability, zTarget=zTarget) + end + + constructSubsamplesMT!(mtData, mtGrnData; totSS=totSS, subsampleFrac=subsampleFrac) + + # instability estimation — ADMM-fused [NEW CORE STEP] + bstartsEstimateInstabilityMT!(mtGrnData, mtData; + totLambdas = totLambdas, + instabilityLevel = instabilityLevel, + zTarget = zTarget, + outputDir = instabilitiesDir, + lambdaOpt = lambdaOpt, # 
NEW + fusionRatio = fusionRatio, # NEW + ebicGamma = ebicGamma, # NEW + gridSize = gridSize, # NEW + refinementSize = refinementSize, # NEW + elasticNetAlpha = elasticNetAlpha) # NEW + + # lambda selection and edge ranking — UNCHANGED per-task logic + mtBuildGrn = MultitaskBuildGrn() + mtBuildGrn.tasks = mtData.tasks + for _ in mtData.tasks + push!(mtBuildGrn.taskBuildGrn, BuildGrn()) + end + + chooseLambdaMT!(mtGrnData, mtBuildGrn; + instabilityLevel=instabilityLevel, + targetInstability=targetInstability) + + rankEdgesMT!(mtData, tfaDataVec, mtGrnData, mtBuildGrn; + mergedTFsData=mergedTFsData, + useMeanEdgesPerGeneMode=useMeanEdgesPerGeneMode, + meanEdgesPerGene=meanEdgesPerGene, + correlationWeight=correlationWeight, + outputDir=instabilitiesDir) + + writeNetworkTableMT!(mtBuildGrn; outputDir=instabilitiesDir) + end + + # ── STEP 8: Combine TFA and TFmRNA networks [UNCHANGED] ────────────────── + # Combine per-task within each TFA option, then across options + for task in mtData.tasks + combinedNetDir = joinpath(dirOut, "Combined", task) + mkpath(combinedNetDir) + nets2combine = [ + joinpath(dirOut, "TFA", task, "edges.tsv"), + joinpath(dirOut, "TFmRNA", task, "edges.tsv") + ] + combineGRNs(nets2combine; + combineOpt=combineOpt, + meanEdgesPerGene=meanEdgesPerGene, + useMeanEdgesPerGeneMode=useMeanEdgesPerGeneMode, + saveDir=combinedNetDir, + saveName=task) + + # re-estimate TFA on combined network [UNCHANGED] + netsCombinedSparse = joinpath(combinedNetDir, "combined_$(task)_$(combineOpt)_sp.tsv") + taskIdx = findfirst(x -> x == task, mtData.tasks) + combineGRNS2(mtData.taskData[taskIdx], mergedTFsData, tfaGeneFile, + netsCombinedSparse, edgeSS, minTargets, + geneExprFile, targFile, regFile; outputDir=combinedNetDir) + end + + println("=== Multitask Inferelator Complete ===") +end + + +# ── Run ──────────────────────────────────────────────────────────────────────── +runMultitaskInferelator( + geneExprFile = 
"/data/miraldiNB/wayman/projects/Tfh10/outs/202404/pseudobulk/counts_Tfh10_vst.txt", + targFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/target_genes/gene_targ_Tfh10.txt", + regFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/pot_regs/TF_Tfh10_final.txt", + priorFile = "/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv", + priorFilePenalties = ["/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv"], + taskLabelFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/taskLabels.txt", # NEW + outputDir = "/data/miraldiNB/Michael/projects/GRN/mCD4T_Wayman/MultitaskInferelator/test", + fusionLambda = 0.1, # fusion penalty (used by :fixed_ratio) + similaritySource = :expression, # or :metadata / :ontology with similarityFile + # ── Lambda optimization (pick one) ─────────────────────────────────────── + lambdaOpt = :fixed_ratio, # cheapest — good for first runs + # lambdaOpt = :ebic, # principled — recommended for production + # lambdaOpt = :bstars_2d, # strongest — use for benchmarking + fusionRatio = 0.1, # for :fixed_ratio — lambda_f = fusionRatio * lambda_s + ebicGamma = 1.0, # EBIC gamma — 1.0 recommended when p >> n + gridSize = 10, # lambda grid size for :ebic and :bstars_2d + refinementSize = 10, # fine grid size for :bstars_2d stage 2 + # ── Elastic net ────────────────────────────────────────────────────────── + elasticNetAlpha = 1.0 # 1.0=pure LASSO | 0.5=elastic net (recommended with demerged TFs) +) diff --git a/experimental/MTL/main.jl b/experimental/MTL/main.jl new file mode 100755 index 0000000..f883862 --- /dev/null +++ b/experimental/MTL/main.jl @@ -0,0 +1,228 @@ +cd("/data/miraldiNB/Michael/Scripts/GRN/MultitaskInferelator") +using Pkg +Pkg.activate(".") +using Revise +include("src/MultitaskInferelator.jl") +using .MultitaskInferelator + +""" + runMultitaskInferelator(; kwargs...) 
+ +Multitask extension of the original Inferelator pipeline. + +Changes vs original runInferelator: +──────────────────────────────────────────────────────────────────── +UNCHANGED loadExpressionData! load expression matrix +UNCHANGED loadAndFilterTargetGenes! filter target genes +UNCHANGED loadPotentialRegulators! load TF list +UNCHANGED processTFAGenes! set TFA gene set +UNCHANGED mergeDegenerateTFs merge degenerate TFs +UNCHANGED processPriorFile! process prior file +UNCHANGED preparePenaltyMatrix! build penalty matrix (per task) +UNCHANGED constructSubsamples subsample indices (per task) +UNCHANGED bstarsWarmStart coarse lambda bounds (per task) +UNCHANGED chooseLambda! lambda selection (per task) +UNCHANGED rankEdges! edge ranking (per task) +UNCHANGED writeNetworkTable! write edge tables (per task) +UNCHANGED combineGRNs combine TFA/TFmRNA networks +UNCHANGED combineGRNS2 re-estimate TFA on combined + +NEW splitByTask! split expression by cell type +NEW buildSimilarityFrom* construct task graph +NEW calculateTFAPerTask! per-task TFA estimation +NEW bstartsEstimateInstabilityMT ADMM-fused instability estimation +NEW buildConsensus! 
consensus network across tasks + +New parameters vs original: +──────────────────────────────────────────────────────────────────── +taskLabelFile path to TSV mapping samples → task names +fusionLambda λ_f fusion penalty strength (default 0.1) +similaritySource :metadata, :expression, or :ontology +similarityFile path to metadata/ontology file (if needed) +""" +function runMultitaskInferelator(; + geneExprFile::String, + targFile::String, + regFile::String, + priorFile::String, + priorFilePenalties::Vector{String}, + taskLabelFile::String, # NEW — maps samples → tasks + tfaGeneFile::String = "", + outputDir::String, + tfaOptions::Vector{String} = ["", "TFmRNA"], + totSS::Int = 80, + bstarsTotSS::Int = 5, + subsampleFrac::Float64 = 0.68, + minLambda::Float64 = 0.01, + maxLambda::Float64 = 0.5, + totLambdasBstars::Int = 20, + totLambdas::Int = 40, + targetInstability::Float64 = 0.05, + meanEdgesPerGene::Int = 20, + correlationWeight::Int = 1, + minTargets::Int = 3, + edgeSS::Int = 0, + lambdaBias::Vector{Float64} = [0.5], + instabilityLevel::String = "Gene", + useMeanEdgesPerGeneMode::Bool = true, + combineOpt::String = "max", + zTarget::Bool = true, + fusionLambda::Float64 = 0.1, # NEW — fusion penalty + similaritySource::Symbol = :expression, # NEW — how to build task graph + similarityFile::Union{String, Nothing} = nothing # NEW — metadata/ontology file +) + + # build output directory + subsamplePct = subsampleFrac * 100 + subsampleStr = isinteger(subsamplePct) ? string(Int(subsamplePct)) : replace(string(subsamplePct), "." => "p") + lambdaStr = join(replace.(string.(lambdaBias), "." 
=> "p"), "_") + networkBaseName = "MT_" * lowercase(instabilityLevel) * "Lambda" * lambdaStr * "_" * + string(totSS) * "totSS_" * string(meanEdgesPerGene) * "tfsPerGene_subsamplePCT" * subsampleStr + dirOut = joinpath(outputDir, networkBaseName) + mkpath(dirOut) + + println("=== Multitask Inferelator Configuration ===") + println("Output Directory: ", dirOut) + println("Expression File: ", geneExprFile) + println("Task Label File: ", taskLabelFile) + println("Prior File: ", priorFile) + println("Fusion Lambda (λ_f): ", fusionLambda) + println("Similarity Source: ", similaritySource) + println("===========================================") + + # ── STEP 1: Load expression data [UNCHANGED] ────────────────────────────── + data = GeneExpressionData() + loadExpressionData!(data, geneExprFile) + loadAndFilterTargetGenes!(data, targFile; epsilon=0.01) + loadPotentialRegulators!(data, regFile) + processTFAGenes!(data, tfaGeneFile; outputDir=dirOut) + + # ── STEP 2: Split by task [NEW] ─────────────────────────────────────────── + taskLabels = vec(readdlm(taskLabelFile, String)) # one label per sample column + mtData = MultitaskExpressionData() + splitByTask!(mtData, data, taskLabels) + + # ── STEP 3: Build task similarity graph [NEW] ───────────────────────────── + if similaritySource == :metadata && similarityFile !== nothing + S = buildSimilarityFromMetadata(mtData.tasks, similarityFile) + elseif similaritySource == :ontology && similarityFile !== nothing + S = buildSimilarityFromOntology(mtData.tasks, similarityFile) + else + println("⚠ Using expression-based similarity — see taskSimilarity.jl for caveats.") + S = buildSimilarityFromExpression(mtData) + end + normalizeSimilarity!(S) + mtData.taskSimilarity = S + + # ── STEP 4: Merge degenerate TFs [UNCHANGED] ────────────────────────────── + mergedTFsData = mergedTFsResult() + mergeDegenerateTFs(mergedTFsData, priorFile; fileFormat=2) + + # ── STEP 5: Process prior file [UNCHANGED] ──────────────────────────────── + # 
Prior is shared across tasks — one prior file, processed once + tfaDataTemplate = PriorTFAData() + processPriorFile!(tfaDataTemplate, data, priorFile; mergedTFsData, minTargets=minTargets) + + # ── STEP 6: Per-task TFA estimation [NEW] ──────────────────────────────── + # Each task gets its own TFA estimate + tfaDataVec = [deepcopy(tfaDataTemplate) for _ in mtData.tasks] + calculateTFAPerTask!(tfaDataVec, mtData; + edgeSS=edgeSS, zTarget=zTarget, outputDir=dirOut) + + # ── STEP 7: Build GRN for each TFA option ──────────────────────────────── + for tfaOpt in tfaOptions + optName = tfaOpt == "" ? "TFA" : "TFmRNA" + instabilitiesDir = joinpath(dirOut, optName) + mkpath(instabilitiesDir) + + # initialize per-task GrnData + mtGrnData = MultitaskGrnData() + mtGrnData.fusionLambda = fusionLambda + mtGrnData.taskGraph = S + for _ in mtData.tasks + push!(mtGrnData.taskGrnData, GrnData()) + end + + # prepare matrices — UNCHANGED per-task logic + preparePredictorMatMT!(mtGrnData, mtData, tfaDataVec, tfaOpt) + preparePenaltyMatrixMT!(mtData, mtGrnData, priorFilePenalties, lambdaBias, tfaOpt) + constructSubsamplesMT!(mtData, mtGrnData; totSS=bstarsTotSS, subsampleFrac=subsampleFrac) + + # coarse warm start — run per task independently [UNCHANGED] + for d in 1:length(mtData.tasks) + bstarsWarmStart(mtData.taskData[d], tfaDataVec[d], mtGrnData.taskGrnData[d]; + minLambda=minLambda, maxLambda=maxLambda, + totLambdasBstars=totLambdasBstars, + targetInstability=targetInstability, zTarget=zTarget) + end + + constructSubsamplesMT!(mtData, mtGrnData; totSS=totSS, subsampleFrac=subsampleFrac) + + # instability estimation — ADMM-fused [NEW CORE STEP] + bstartsEstimateInstabilityMT!(mtGrnData, mtData; + totLambdas=totLambdas, + instabilityLevel=instabilityLevel, + zTarget=zTarget, + outputDir=instabilitiesDir) + + # lambda selection and edge ranking — UNCHANGED per-task logic + mtBuildGrn = MultitaskBuildGrn() + mtBuildGrn.tasks = mtData.tasks + for _ in mtData.tasks + 
push!(mtBuildGrn.taskBuildGrn, BuildGrn()) + end + + chooseLambdaMT!(mtGrnData, mtBuildGrn; + instabilityLevel=instabilityLevel, + targetInstability=targetInstability) + + rankEdgesMT!(mtData, tfaDataVec, mtGrnData, mtBuildGrn; + mergedTFsData=mergedTFsData, + useMeanEdgesPerGeneMode=useMeanEdgesPerGeneMode, + meanEdgesPerGene=meanEdgesPerGene, + correlationWeight=correlationWeight, + outputDir=instabilitiesDir) + + writeNetworkTableMT!(mtBuildGrn; outputDir=instabilitiesDir) + end + + # ── STEP 8: Combine TFA and TFmRNA networks [UNCHANGED] ────────────────── + # Combine per-task within each TFA option, then across options + for task in mtData.tasks + combinedNetDir = joinpath(dirOut, "Combined", task) + mkpath(combinedNetDir) + nets2combine = [ + joinpath(dirOut, "TFA", task, "edges.tsv"), + joinpath(dirOut, "TFmRNA", task, "edges.tsv") + ] + combineGRNs(nets2combine; + combineOpt=combineOpt, + meanEdgesPerGene=meanEdgesPerGene, + useMeanEdgesPerGeneMode=useMeanEdgesPerGeneMode, + saveDir=combinedNetDir, + saveName=task) + + # re-estimate TFA on combined network [UNCHANGED] + netsCombinedSparse = joinpath(combinedNetDir, "combined_$(task)_$(combineOpt)_sp.tsv") + taskIdx = findfirst(x -> x == task, mtData.tasks) + combineGRNS2(mtData.taskData[taskIdx], mergedTFsData, tfaGeneFile, + netsCombinedSparse, edgeSS, minTargets, + geneExprFile, targFile, regFile; outputDir=combinedNetDir) + end + + println("=== Multitask Inferelator Complete ===") +end + + +# ── Run ──────────────────────────────────────────────────────────────────────── +runMultitaskInferelator( + geneExprFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/pseudobulk/counts_Tfh10_vst.txt", + targFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/target_genes/gene_targ_Tfh10.txt", + regFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/GRN_NoState/inputs/pot_regs/TF_Tfh10_final.txt", + priorFile = 
"/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv", + priorFilePenalties = ["/data/miraldiNB/Michael/Scripts/GRN/Inferelator_JL/Tfh10_Example/inputs/priors/ATAC/ATAC_Tfh10.tsv"], + taskLabelFile = "/data/miraldiNB/wayman/projects/Tfh10/outs/202404/taskLabels.txt", # NEW + outputDir = "/data/miraldiNB/Michael/projects/GRN/mCD4T_Wayman/MultitaskInferelator/test", + fusionLambda = 0.1, # NEW — tune this + similaritySource = :expression # NEW — or :metadata with similarityFile +) diff --git a/experimental/MTL/multitaskData.jl b/experimental/MTL/multitaskData.jl new file mode 100755 index 0000000..079a329 --- /dev/null +++ b/experimental/MTL/multitaskData.jl @@ -0,0 +1,96 @@ +""" +multitaskData.jl [NEW FILE] + +Defines the data structures needed for multitask inference. +The key addition over the original codebase is splitting one expression +matrix into per-task views, and storing per-task GRN results. +""" + +# ── Task-split expression data ──────────────────────────────────────────────── +""" + MultitaskExpressionData + +Holds per-task views of the expression data. Each task corresponds to +a biological context (e.g. cell type) defined by the column labels of +the original expression matrix. + +Fields +────── +- `tasks` : names of each task (e.g. 
["Tfh", "Th1", "Treg"]) +- `taskData` : one GeneExpressionData per task [per-task] +- `taskSimilarity` : tasks × tasks similarity matrix [NEW] +- `taskLabels` : maps each sample column → task name [NEW] +""" +mutable struct MultitaskExpressionData + tasks::Vector{String} + taskData::Vector{GeneExpressionData} + taskSimilarity::Matrix{Float64} + taskLabels::Vector{String} + + function MultitaskExpressionData() + return new( + String[], + GeneExpressionData[], + Matrix{Float64}(undef, 0, 0), + String[] + ) + end +end + + +# ── Per-task GRN data ───────────────────────────────────────────────────────── +""" + MultitaskGrnData + +Wraps one GrnData per task plus the fusion penalty parameter. +The per-task GrnData structs are identical to the original codebase — +the only addition is `fusionLambda` which controls how strongly +related tasks are pulled toward each other. + +Fields +────── +- `taskGrnData` : one GrnData per task [per-task] +- `fusionLambda` : λ_f — fusion penalty strength [NEW] +- `taskGraph` : tasks × tasks similarity (copied from MT data for convenience) +""" +mutable struct MultitaskGrnData + taskGrnData::Vector{GrnData} + fusionLambda::Float64 + taskGraph::Matrix{Float64} + + function MultitaskGrnData() + return new( + GrnData[], + 0.1, + Matrix{Float64}(undef, 0, 0) + ) + end +end + + +# ── Per-task BuildGrn ───────────────────────────────────────────────────────── +""" + MultitaskBuildGrn + +Holds one BuildGrn per task plus the consensus network +(averaged across tasks after fusion). 
+ +Fields +────── +- `taskBuildGrn` : one BuildGrn per task [per-task] +- `consensusNetwork` : gene × TF matrix averaged across tasks [NEW] +- `tasks` : task names for indexing [NEW] +""" +mutable struct MultitaskBuildGrn + taskBuildGrn::Vector{BuildGrn} + consensusNetwork::Matrix{Float64} + tasks::Vector{String} + + function MultitaskBuildGrn() + return new( + BuildGrn[], + Matrix{Float64}(undef, 0, 0), + String[] + ) + end +end diff --git a/experimental/MTL/prepareMultitask.jl b/experimental/MTL/prepareMultitask.jl new file mode 100755 index 0000000..0b71e6f --- /dev/null +++ b/experimental/MTL/prepareMultitask.jl @@ -0,0 +1,272 @@ +""" +prepareMultitask.jl [NEW FILE] + +Multitask versions of the preparation functions from prepareGRN.jl. + +What stays the same vs original: +───────────────────────────────────────────────────────────────────── +UNCHANGED preparePredictorMat! — called once per task +UNCHANGED preparePenaltyMatrix! — called once per task +UNCHANGED constructSubsamples — called once per task +UNCHANGED bstarsWarmStart — called once per task (coarse pass) +NEW splitByTask! — splits expression matrix by task +NEW preparePredictorMatMT! — loops preparePredictorMat! over tasks +NEW preparePenaltyMatrixMT! — loops preparePenaltyMatrix! over tasks +NEW constructSubsamplesMT! — loops constructSubsamples over tasks +NEW bstartsEstimateInstabilityMT — replaces bstartsEstimateInstability, + uses ADMM instead of GLMNet per gene +""" + + +# ── NEW: Split expression matrix by task ───────────────────────────────────── +""" + splitByTask!(mtData::MultitaskExpressionData, data::GeneExpressionData, taskLabels::Vector{String}) + +Split one GeneExpressionData into per-task views based on column labels. + +`taskLabels` must have the same length as `data.cellLabels` and maps +each sample column to a task name (e.g. "Tfh", "Th1", "Treg"). + +This is the only entry point that creates per-task data — everything +downstream just iterates over `mtData.taskData`. 
+""" +function splitByTask!(mtData::MultitaskExpressionData, + data::GeneExpressionData, + taskLabels::Vector{String}) + + if length(taskLabels) != length(data.cellLabels) + error("taskLabels length ($(length(taskLabels))) must match number of samples ($(length(data.cellLabels)))") + end + + uniqueTasks = unique(taskLabels) + mtData.taskLabels = taskLabels + mtData.tasks = uniqueTasks + + for task in uniqueTasks + taskInds = findall(x -> x == task, taskLabels) + + td = GeneExpressionData() + # shared metadata — same across all tasks + td.geneNames = data.geneNames + td.targGenes = data.targGenes + td.potRegs = data.potRegs + td.potRegsmRNA = data.potRegsmRNA + td.tfaGenes = data.tfaGenes + # per-task column subsets + td.cellLabels = data.cellLabels[taskInds] + td.geneExpressionMat = data.geneExpressionMat[:, taskInds] + td.targGeneMat = data.targGeneMat[:, taskInds] + td.potRegMatmRNA = data.potRegMatmRNA[:, taskInds] + td.tfaGeneMat = data.tfaGeneMat[:, taskInds] + + push!(mtData.taskData, td) + end + + println("Split into $(length(uniqueTasks)) tasks: ", join(uniqueTasks, ", ")) +end + + +# ── NEW: Per-task TFA estimation ────────────────────────────────────────────── +""" + calculateTFAPerTask!(tfaDataVec, mtData; edgeSS=0, zTarget=false, outputDir=nothing) + +Run calculateTFA! independently for each task. +Returns a Vector{PriorTFAData}, one per task. + +Each task gets its own TFA estimate because TF activity can differ +substantially across cell types — pooling them would obscure this. +""" +function calculateTFAPerTask!(tfaDataVec::Vector{PriorTFAData}, + mtData::MultitaskExpressionData; + edgeSS::Int = 0, + zTarget::Bool = false, + outputDir::Union{String, Nothing} = nothing) + + for (d, taskName) in enumerate(mtData.tasks) + taskDir = outputDir !== nothing ? 
joinpath(outputDir, taskName) : nothing + if taskDir !== nothing + mkpath(taskDir) + end + calculateTFA!(tfaDataVec[d], mtData.taskData[d]; + edgeSS = edgeSS, zscoreTargExp = zTarget, outputDir = taskDir) + println("TFA estimated for task: ", taskName) + end +end + + +# ── NEW: Loop prepare functions over tasks ──────────────────────────────────── +""" + preparePredictorMatMT!(mtGrnData, mtData, tfaDataVec, tfaOpt) + +Call preparePredictorMat! for each task. [UNCHANGED per-task logic] +""" +function preparePredictorMatMT!(mtGrnData::MultitaskGrnData, + mtData::MultitaskExpressionData, + tfaDataVec::Vector{PriorTFAData}, + tfaOpt::String) + for d in 1:length(mtData.tasks) + preparePredictorMat!(mtGrnData.taskGrnData[d], mtData.taskData[d], tfaDataVec[d], tfaOpt) + end +end + + +""" + preparePenaltyMatrixMT!(mtData, mtGrnData, priorFilePenalties, lambdaBias, tfaOpt) + +Call preparePenaltyMatrix! for each task. [UNCHANGED per-task logic] +""" +function preparePenaltyMatrixMT!(mtData::MultitaskExpressionData, + mtGrnData::MultitaskGrnData, + priorFilePenalties::Vector{String}, + lambdaBias::Vector{Float64}, + tfaOpt::String) + for d in 1:length(mtData.tasks) + preparePenaltyMatrix!(mtData.taskData[d], mtGrnData.taskGrnData[d], + priorFilePenalties, lambdaBias, tfaOpt) + end +end + + +""" + constructSubsamplesMT!(mtData, mtGrnData; totSS, subsampleFrac) + +Call constructSubsamples for each task. 
[UNCHANGED per-task logic] +""" +function constructSubsamplesMT!(mtData::MultitaskExpressionData, + mtGrnData::MultitaskGrnData; + totSS::Int = 50, + subsampleFrac::Float64 = 0.68) + for d in 1:length(mtData.tasks) + constructSubsamples(mtData.taskData[d], mtGrnData.taskGrnData[d]; + totSS = totSS, subsampleFrac = subsampleFrac) + end +end + + +# ── NEW: Multitask instability estimation (replaces bstartsEstimateInstability) ── +""" + bstartsEstimateInstabilityMT!(mtGrnData, mtData; + totLambdas, instabilityLevel, zTarget, outputDir) + +Replaces bstartsEstimateInstability from the original codebase. + +Key difference: instead of fitting GLMNet independently per task, +this calls admm_fused_lasso which couples related tasks via λ_f. + +The per-gene parallelism (Threads.@threads) is preserved — tasks are +coupled WITHIN each gene's ADMM problem, not across genes. + +Instability is computed the same way as the original: + θ = (1/totSS) * edge_selection_count + instab = 2 * θ * (1 - θ) +But now per task, giving a tasks × lambdas instability surface per gene. 
+""" +function bstartsEstimateInstabilityMT!(mtGrnData::MultitaskGrnData, + mtData::MultitaskExpressionData; + totLambdas::Int = 10, + instabilityLevel::String = "Gene", + zTarget::Bool = false, + outputDir::Union{String, Nothing} = nothing) + + nTasks = length(mtData.tasks) + # use first task to get dimensions — all tasks share same gene/TF sets + refGrn = mtGrnData.taskGrnData[1] + totResponses, totSamps = size(refGrn.responseMat) + totPreds = size(refGrn.predictorMat, 1) + totSS = size(refGrn.subsamps, 1) + + # Lambda range: use network-level bounds from first task as reference + # TODO: consider per-task lambda ranges for heterogeneous tasks + minLambda = minimum([g.minLambdaNet for g in mtGrnData.taskGrnData]) + maxLambda = maximum([g.maxLambdaNet for g in mtGrnData.taskGrnData]) + lambdaRange = reverse(collect(range(minLambda, stop=maxLambda, length=totLambdas))) + + # Store edge selection counts: lambdas × genes × TFs × tasks + ssMatrix = Inf * ones(totLambdas, totResponses, totPreds, nTasks) + betas = Array{Float64, 4}(undef, totResponses, totPreds, totLambdas, nTasks) + + # get finite predictor indices per task per response + responsePredInds = [[findall(x -> x != Inf, mtGrnData.taskGrnData[d].penaltyMat[res, :]) + for res in 1:totResponses] + for d in 1:nTasks] + + Threads.@threads for res in ProgressBar(1:totResponses) + for ss in 1:totSS + # collect per-task predictors and responses for this subsample + taskPredMats = Matrix{Float64}[] + taskRespVecs = Vector{Float64}[] + taskPenalties = Vector{Float64}[] + + for d in 1:nTasks + subsamp = mtGrnData.taskGrnData[d].subsamps[ss, :] + predInds = responsePredInds[d][res] + + dt = fit(ZScoreTransform, + mtGrnData.taskGrnData[d].predictorMat[predInds, subsamp], dims=2) + currPreds = transpose(StatsBase.transform(dt, + mtGrnData.taskGrnData[d].predictorMat[predInds, subsamp])) + + if zTarget + dt2 = fit(ZScoreTransform, + mtGrnData.taskGrnData[d].responseMat[res, subsamp], dims=1) + currResp = 
StatsBase.transform(dt2, + mtGrnData.taskGrnData[d].responseMat[res, subsamp]) + else + currResp = mtGrnData.taskGrnData[d].responseMat[res, subsamp] + end + + push!(taskPredMats, transpose(currPreds)) # TFs × subsampled_cells + push!(taskRespVecs, vec(currResp)) + push!(taskPenalties, mtGrnData.taskGrnData[d].penaltyMat[res, predInds]) + end + + # ── ADMM: couples tasks via fusion penalty ───────────────────────── + # This is the key difference from the original codebase. + # Original: glmnet(currPreds, currResponses, ...) per task independently + # New: admm_fused_lasso(...) jointly over all tasks + betasByLambda = admmWarmStart( + taskPredMats, taskRespVecs, taskPenalties, + mtGrnData.taskGraph, lambdaRange, mtGrnData.fusionLambda + ) # TFs × tasks × lambdas + + for d in 1:nTasks + predInds = responsePredInds[d][res] + for (li, _) in enumerate(lambdaRange) + ssMatrix[li, res, predInds, d] .+= abs.(sign.(betasByLambda[:, d, li])) + betas[res, predInds, li, d] = betasByLambda[:, d, li] + end + end + end + end + + # compute instabilities per task (same formula as original) + for d in 1:nTasks + grnD = mtGrnData.taskGrnData[d] + geneInstabilities = zeros(totResponses, totLambdas) + + for res in 1:totResponses + predInds = responsePredInds[d][res] + ssVals = ssMatrix[:, res, predInds, d] + theta2 = (1 / totSS) * ssVals + instabPerEdge = 2 * (theta2 .* (1 .- theta2)) + aveInstab = vec(mean(instabPerEdge, dims=2)) + maxUb = maximum(aveInstab) + maxUbInd = findlast(x -> x == maxUb, aveInstab) + aveInstab[maxUbInd:end] .= maxUb + geneInstabilities[res, :] = aveInstab + end + + grnD.geneInstabilities = geneInstabilities + grnD.lambdaRange = lambdaRange + grnD.stabilityMat = ssMatrix[:, :, :, d] + grnD.betas = betas[:, :, :, d] + + if outputDir !== nothing + taskDir = joinpath(outputDir, mtData.tasks[d]) + mkpath(taskDir) + save_object(joinpath(taskDir, "instabOutMat.jld"), grnD) + end + end + + println("Multitask instability estimation complete.") +end diff --git 
"""
taskSimilarity.jl [NEW FILE]

Constructs the task similarity graph used as the fusion penalty structure.
The graph determines which tasks are pulled toward each other by λ_f.

Key design principle: similarity should come from an INDEPENDENT source,
not from the same expression data used to fit the model (circular).
Three options are provided in order of preference.
"""


"""
    buildSimilarityFromMetadata(tasks, metadataFile)

Build task similarity from a user-supplied metadata file.
This is the preferred approach — avoids circular use of expression data.

The metadata file should be a TSV with columns: Task, Group
where Group defines which tasks are biologically related.
Tasks in the same group get similarity = 1, across groups = 0.
Tasks absent from the metadata file get similarity 0 to every other task
(but 1 to themselves, since the diagonal is always set to 1).

# Arguments
- `tasks` : task names in order matching MultitaskExpressionData.tasks
- `metadataFile` : path to TSV metadata file
"""
function buildSimilarityFromMetadata(tasks::Vector{String}, metadataFile::String)
    df = CSV.read(metadataFile, DataFrame; delim='\t')
    nTasks = length(tasks)
    S = zeros(Float64, nTasks, nTasks)

    # last occurrence wins if a task appears twice in the metadata
    taskToGroup = Dict(row.Task => row.Group for row in eachrow(df))

    for i in 1:nTasks
        for j in 1:nTasks
            gi = get(taskToGroup, tasks[i], nothing)
            gj = get(taskToGroup, tasks[j], nothing)
            # both tasks must be present AND share a group
            if gi !== nothing && gj !== nothing && gi == gj
                S[i, j] = 1.0
            end
        end
    end
    # diagonal is always 1
    for i in 1:nTasks
        S[i, i] = 1.0
    end
    return S
end


"""
    buildSimilarityFromExpression(mtData::MultitaskExpressionData; method=:pearson)

Build task similarity from mean expression profiles per task.
"""
    buildSimilarityFromExpression(mtData::MultitaskExpressionData; method=:pearson)

Derive task–task similarity from each task's mean expression profile.

⚠️ WARNING: This is circular — the same expression data defines task
similarity AND fits the model. Only use this as a last resort when no
independent metadata is available; consider restricting the correlation
to a held-out gene subset (e.g. housekeeping genes).

Negative correlations are clipped to 0 and treated as "no similarity".

# Arguments
- `mtData` : MultitaskExpressionData with taskData populated
- `method` : :pearson (default) or :spearman
"""
function buildSimilarityFromExpression(mtData::MultitaskExpressionData; method::Symbol = :pearson)
    nTasks = length(mtData.tasks)
    # genes × tasks matrix of per-task mean expression profiles
    profiles = hcat([vec(mean(taskD.targGeneMat, dims = 2)) for taskD in mtData.taskData]...)

    # pairwise similarity for one pair of profile columns; an unrecognized
    # `method` yields 0.0, matching the original's untouched zero entries
    pairSim(a, b) =
        method == :pearson  ? cor(a, b) :
        method == :spearman ? cor(tiedrank(a), tiedrank(b)) :
        0.0

    S = [max(pairSim(profiles[:, i], profiles[:, j]), 0.0)
         for i in 1:nTasks, j in 1:nTasks]
    return S
end
"""
    buildSimilarityFromOntology(tasks, ontologyFile)

Build task similarity from a cell type ontology distance file.
The ontology file must be a TSV with columns: Task1, Task2, Distance,
where Distance is non-negative (0 = identical, larger = more distant).

Similarity is an RBF kernel, sim = exp(-distance / scale), with scale set
to the median off-diagonal pairwise distance.

NOTE(review): pairs never mentioned in the file keep distance 0 and thus
map to similarity 1.0 — confirm this is intended for unlisted task pairs.
"""
function buildSimilarityFromOntology(tasks::Vector{String}, ontologyFile::String)
    df = CSV.read(ontologyFile, DataFrame; delim='\t')
    nTasks = length(tasks)
    indexOf = Dict(name => k for (k, name) in enumerate(tasks))
    dists = zeros(Float64, nTasks, nTasks)

    # fill the symmetric distance matrix, skipping rows naming unknown tasks
    for row in eachrow(df)
        haskey(indexOf, row.Task1) || continue
        haskey(indexOf, row.Task2) || continue
        a = indexOf[row.Task1]
        b = indexOf[row.Task2]
        dists[a, b] = row.Distance
        dists[b, a] = row.Distance
    end

    # distance → similarity via RBF kernel, bandwidth = median off-diagonal distance
    offDiagonal = [dists[a, b] for a in 1:nTasks for b in 1:nTasks if a != b]
    bandwidth = isempty(offDiagonal) ? 1.0 : median(offDiagonal)
    return exp.(-dists ./ max(bandwidth, 1e-8))
end


"""
    normalizeSimilarity!(S::Matrix{Float64})

Row-normalize `S` in place so each row sums to 1 (rows that sum to zero or
less are left untouched). This keeps the fusion penalty scale-invariant
across tasks with different numbers of neighbors. Returns `S`.
"""
function normalizeSimilarity!(S::Matrix{Float64})
    for i in axes(S, 1)
        total = sum(S[i, :])
        if total > 0
            S[i, :] ./= total
        end
    end
    return S
end
# -----------------------------------------------------------------------------
# STEP 1 · Load & filter all expression data
# -----------------------------------------------------------------------------
"""
    loadData(exprFile, targFile, regFile; tfaGeneFile="", epsilon=0.01)

Load every expression input, apply the standard filters, and assemble the
result into a single `GeneExpressionData`.

# Arguments
- `exprFile` : Path to gene expression matrix (TSV or Arrow, genes × samples)
- `targFile` : Path to target gene list (genes modeled as responses)
- `regFile` : Path to potential regulator list (candidate TFs)
- `tfaGeneFile` : Optional path to gene list used for TFA estimation
- `epsilon` : Minimum variance threshold for target gene filtering (default 0.01)

# Returns
`GeneExpressionData` with `geneExpressionMat`, `targGeneMat`,
`potRegMatmRNA`, and `tfaGeneMat` populated.
"""
function loadData(
    exprFile::String,
    targFile::String,
    regFile::String;
    tfaGeneFile::String = "",
    epsilon::Float64 = 0.01
)::GeneExpressionData
    # Each step mutates `exprData` in place; the order mirrors the pipeline:
    # raw expression → target filtering → regulator selection → TFA gene set.
    exprData = GeneExpressionData()
    loadExpressionData!(exprData, exprFile)
    loadAndFilterTargetGenes!(exprData, targFile; epsilon = epsilon)
    loadPotentialRegulators!(exprData, regFile)
    processTFAGenes!(exprData, tfaGeneFile)
    return exprData
end
# Arguments
- `data` : Populated `GeneExpressionData` from `loadData`
- `priorFile` : Path to TF × Gene prior matrix (TSV)
- `minTargets` : Minimum number of targets a TF must have to be retained (default 3)
- `mergeDegenerate` : Whether to collapse TFs with identical target sets (default true)
- `connector` : String used to join meta-TF names (default "_")

# Returns
Tuple `(PriorTFAData, mergedTFsResult)`
"""
function loadPrior(
    data::GeneExpressionData,
    priorFile::String;
    minTargets::Int = 3,
    mergeDegenerate::Bool = true,
    connector::String = "_"
)::Tuple{PriorTFAData, mergedTFsResult}

    # mergedTFs stays empty (fields `nothing`) when mergeDegenerate=false
    mergedTFs = mergedTFsResult()
    if mergeDegenerate
        mergeDegenerateTFs(mergedTFs, priorFile; fileFormat = 2, connector = connector)
    end

    priorData = PriorTFAData()
    processPriorFile!(priorData, data, priorFile;
                      mergedTFsData = mergedTFs,
                      minTargets = minTargets)
    return priorData, mergedTFs
end


# -----------------------------------------------------------------------------
# STEP 3 (cont.) · Estimate TF activity
# -----------------------------------------------------------------------------
"""
    estimateTFA(priorData, data; edgeSS=0, zScoreTFA=true, outputDir=".")

Estimate TF activity (TFA) by solving `prior * TFA ≈ targetExpression`
via least squares, with optional bootstrap subsampling of targets.

# Arguments
- `priorData` : `PriorTFAData` from `loadPrior`
- `data` : `GeneExpressionData` from `loadData`
- `edgeSS` : Number of edge subsamples (0 = no subsampling)
- `zScoreTFA` : Z-score target expression before solving TFA (default true)
- `outputDir` : Directory for intermediate output files (default ".")

# Returns
TFA matrix as `Matrix{Float64}` (TFs × samples), also stored in `priorData.medTfas`
"""
function estimateTFA(
    priorData::PriorTFAData,
    data::GeneExpressionData;
    edgeSS::Int = 0,
    zScoreTFA::Bool = true,
    outputDir::String = "."
)::Matrix{Float64}

    # translate the public kwarg name (zScoreTFA) to the internal one (zTarget)
    calculateTFA!(priorData, data;
                  edgeSS = edgeSS,
                  zTarget = zScoreTFA,
                  outputDir = outputDir)
    return priorData.medTfas
end


# -----------------------------------------------------------------------------
# STEP 4 · Build a single GRN (one predictor mode)
# -----------------------------------------------------------------------------
"""
    buildNetwork(data, priorData; kwargs...) → BuildGrn

Run the full mLASSO-StARS pipeline for one predictor mode and return the
ranked edge table.

Internally runs:
`preparePredictorMat!` → `preparePenaltyMatrix!` → `constructSubsamples` →
`bstarsWarmStart` → `constructSubsamples` (full) → `bstartsEstimateInstability` →
`chooseLambda!` → `rankEdges!` → `writeNetworkTable!`

# Arguments
- `data` : `GeneExpressionData`
- `priorData` : `PriorTFAData`
- `tfaMode` : true = TFA predictors, false = raw mRNA for all TFs
- `priorFilePenalties` : Prior file(s) used to build the penalty matrix
- `lambdaBias` : Penalty reduction factor for prior edges (default [0.5])
- `totSS` : Total subsamples for fine instability estimation (default 80)
- `bstarsTotSS` : Subsamples for warm-start (default 5)
- `subsampleFrac` : Fraction of samples per subsample (default 0.63)
- `minLambda` : Lower bound of λ search range (default 0.01)
- `maxLambda` : Upper bound of λ search range (default 0.5)
- `totLambdasBstars` : λ grid points in warm-start pass (default 20)
- `totLambdas` : λ grid points in fine estimation pass (default 40)
- `targetInstability` : Instability threshold for λ selection (default 0.05)
- `meanEdgesPerGene` : Max edges retained per target gene (default 20)
- `correlationWeight` : Weight of partial correlation in edge scoring (default 1)
- `instabilityLevel` : "Network" (single λ) or "Gene" (per-gene λ)
- `useMeanEdgesPerGeneMode`: Enforce per-gene edge cap (default true)
- `zScoreLASSO` : Z-score targets during regression (default true);
  forwarded internally as `zTarget`
- `outputDir` : Output directory for edges.tsv and stability arrays

# Returns
`BuildGrn` with `networkStability`, `signedQuantile`, `networkMat`, etc.
"""
function buildNetwork(
    data::GeneExpressionData,
    priorData::PriorTFAData;
    tfaMode::Bool = true,
    priorFilePenalties::Vector{String} = String[],
    lambdaBias::Vector{Float64} = [0.5],
    totSS::Int = 80,
    bstarsTotSS::Int = 5,
    subsampleFrac::Float64 = 0.63,
    minLambda::Float64 = 0.01,
    maxLambda::Float64 = 0.5,
    totLambdasBstars::Int = 20,
    totLambdas::Int = 40,
    targetInstability::Float64 = 0.05,
    meanEdgesPerGene::Int = 20,
    correlationWeight::Int = 1,
    instabilityLevel::String = "Network",
    useMeanEdgesPerGeneMode::Bool = true,
    zScoreLASSO::Bool = true,
    outputDir::String = "."
)::BuildGrn

    # empty tfaOpt selects TFA predictors; "TFmRNA" selects raw TF mRNA
    tfaOpt = tfaMode ? "" : "TFmRNA"

    grnData = GrnData()
    preparePredictorMat!(grnData, data, priorData; tfaOpt = tfaOpt)
    preparePenaltyMatrix!(data, grnData;
                          priorFilePenalties = priorFilePenalties,
                          lambdaBias = lambdaBias,
                          tfaOpt = tfaOpt)

    # Warm-start pass (coarse λ range, few subsamples)
    constructSubsamples(data, grnData; totSS = bstarsTotSS, subsampleFrac = subsampleFrac)
    bstarsWarmStart(data, priorData, grnData;
                    minLambda = minLambda,
                    maxLambda = maxLambda,
                    totLambdasBstars = totLambdasBstars,
                    targetInstability = targetInstability,
                    zTarget = zScoreLASSO)

    # Fine estimation pass (fresh, larger subsample set)
    constructSubsamples(data, grnData; totSS = totSS, subsampleFrac = subsampleFrac)
    bstartsEstimateInstability(grnData;
                               totLambdas = totLambdas,
                               instabilityLevel = instabilityLevel,
                               zTarget = zScoreLASSO,
                               outputDir = outputDir)

    buildGrn = BuildGrn()
    chooseLambda!(grnData, buildGrn;
                  instabilityLevel = instabilityLevel,
                  targetInstability = targetInstability)
    rankEdges!(data, priorData, grnData, buildGrn;
               useMeanEdgesPerGeneMode = useMeanEdgesPerGeneMode,
               meanEdgesPerGene = meanEdgesPerGene,
               correlationWeight = correlationWeight,
               outputDir = outputDir)
    writeNetworkTable!(buildGrn; outputDir = outputDir)

    return buildGrn
end


# -----------------------------------------------------------------------------
# STEP 6 · Recalculate TFA using the combined network as a refined prior
# -----------------------------------------------------------------------------
"""
    refineTFA(combinedNetFile, data, mergedTFs; kwargs...) → Matrix{Float64}

Use the aggregated consensus network as a new prior and re-estimate TF
activity, yielding a data-driven TFA matrix that reflects the final GRN.

# Arguments
- `combinedNetFile` : Path to sparse combined network TSV (combined_max_sp.tsv)
- `data` : `GeneExpressionData`
- `mergedTFs` : `mergedTFsResult` from `loadPrior`
- `tfaGeneFile` : Optional gene list for TFA (default "")
- `edgeSS` : Edge subsampling replicates for TFA (default 0)
- `minTargets` : Minimum targets per TF (default 3)
- `zScoreTFA` : Z-score target expression before solving TFA (default true)
- `exprFile` : Original expression file path
- `targFile` : Target gene file path
- `regFile` : Regulator file path
- `outputDir` : Directory for refined TFA output (default ".")

# Returns
Refined TFA matrix as `Matrix{Float64}` (regulators × samples)
"""
function refineTFA(
    combinedNetFile::String,
    data::GeneExpressionData,
    mergedTFs::mergedTFsResult;
    tfaGeneFile::String = "",
    edgeSS::Int = 0,
    minTargets::Int = 3,
    zScoreTFA::Bool = true,
    exprFile::String = "",
    targFile::String = "",
    regFile::String = "",
    outputDir::String = "."
)

    # Dispatch to internal refineTFA(data::GeneExpressionData, ...) — different
    # first-arg type; its return value is this wrapper's return value.
    refineTFA(data, mergedTFs;
              priorFile = combinedNetFile,
              tfaGeneFile = tfaGeneFile,
              edgeSS = edgeSS,
              minTargets = minTargets,
              zTarget = zScoreTFA,
              geneExprFile = exprFile,
              targFile = targFile,
              regFile = regFile,
              outputDir = outputDir)
end


# -----------------------------------------------------------------------------
# STEP 7 · Evaluate a network against a gold standard
# -----------------------------------------------------------------------------
"""
    evaluateNetwork(networkFile, goldStandard; metric=:AUPR) → NamedTuple

Evaluate a ranked edge list against a gold-standard network.

# Arguments
- `networkFile` : Path to edges TSV (TF, Gene, signedQuantile, ...)
- `goldStandard` : Path to gold-standard network TSV
- `metric` : Evaluation metric — `:AUPR`, `:AUROC`, or `:both` (default `:AUPR`)

# Returns
`NamedTuple` with fields depending on `metric`:
`(aupr=..., auroc=..., precision=..., recall=...)`
"""
function evaluateNetwork(
    networkFile::String,
    goldStandard::String;
    metric::Symbol = :AUPR
)

    # Delegates to Metrics submodule
    computeMacroMetrics(networkFile, goldStandard; metric = metric)
end


# -----------------------------------------------------------------------------
# CONVENIENCE · Full pipeline in one call
# -----------------------------------------------------------------------------
"""
    inferGRN(exprFile, targFile, regFile, priorFile; outputDir="results", kwargs...)

Run the complete InferelatorJL pipeline end-to-end:

1. Load & filter expression data
2. Merge degenerate TFs + process prior
3. Estimate TFA
4. Build TFA-mode network
5. Build mRNA-mode network
6. Aggregate both networks
7. Recalculate TFA on combined network

All keyword arguments are forwarded to the relevant sub-functions.

# Returns
`BuildGrn` from the TFA-mode network (combined results written to `outputDir`)
"""
function inferGRN(
    exprFile::String,
    targFile::String,
    regFile::String,
    priorFile::String;
    outputDir::String = "results",
    tfaGeneFile::String = "",
    epsilon::Float64 = 0.01,
    minTargets::Int = 3,
    edgeSS::Int = 0,
    zScoreTFA::Bool = true,
    zScoreLASSO::Bool = true,
    priorFilePenalties::Vector{String} = String[],
    lambdaBias::Vector{Float64} = [0.5],
    totSS::Int = 80,
    bstarsTotSS::Int = 5,
    subsampleFrac::Float64 = 0.63,
    minLambda::Float64 = 0.01,
    maxLambda::Float64 = 0.5,
    totLambdasBstars::Int = 20,
    totLambdas::Int = 40,
    targetInstability::Float64 = 0.05,
    meanEdgesPerGene::Int = 20,
    correlationWeight::Int = 1,
    instabilityLevel::String = "Network",
    useMeanEdgesPerGeneMode::Bool = true,
    combineMethod::Symbol = :max
)::BuildGrn

    # one subdirectory per predictor mode plus one for the aggregated network
    mkpath(outputDir)
    tfaDir = joinpath(outputDir, "TFA")
    mRNADir = joinpath(outputDir, "TFmRNA")
    combDir = joinpath(outputDir, "Combined")
    mkpath(tfaDir); mkpath(mRNADir); mkpath(combDir)

    # Shared kwargs for buildNetwork; the penalty prior defaults to the TFA prior
    netKwargs = (
        priorFilePenalties = isempty(priorFilePenalties) ? [priorFile] : priorFilePenalties,
        lambdaBias = lambdaBias,
        totSS = totSS,
        bstarsTotSS = bstarsTotSS,
        subsampleFrac = subsampleFrac,
        minLambda = minLambda,
        maxLambda = maxLambda,
        totLambdasBstars = totLambdasBstars,
        totLambdas = totLambdas,
        targetInstability = targetInstability,
        meanEdgesPerGene = meanEdgesPerGene,
        correlationWeight = correlationWeight,
        instabilityLevel = instabilityLevel,
        useMeanEdgesPerGeneMode = useMeanEdgesPerGeneMode,
        zScoreLASSO = zScoreLASSO,
    )

    # Steps 1–3
    data = loadData(exprFile, targFile, regFile;
                    tfaGeneFile = tfaGeneFile, epsilon = epsilon)
    priorData, mergedTFs = loadPrior(data, priorFile; minTargets = minTargets)
    estimateTFA(priorData, data; edgeSS = edgeSS, zScoreTFA = zScoreTFA,
                outputDir = outputDir)

    # Step 4 — TFA mode
    tfaGrn = buildNetwork(data, priorData; tfaMode = true,
                          netKwargs..., outputDir = tfaDir)

    # Step 5 — mRNA mode (edges written to mRNADir; result not returned —
    # only the TFA-mode BuildGrn is)
    mrnaGrn = buildNetwork(data, priorData; tfaMode = false,
                           netKwargs..., outputDir = mRNADir)

    # Step 6 — aggregate
    aggregateNetworks(
        [joinpath(tfaDir, "edges.tsv"),
         joinpath(mRNADir, "edges.tsv")];
        method = combineMethod,
        meanEdgesPerGene = meanEdgesPerGene,
        useMeanEdgesPerGene = useMeanEdgesPerGeneMode,
        outputDir = combDir
    )

    # Step 7 — refine TFA
    combinedSparse = joinpath(combDir, "combined_" * string(combineMethod) * "_sp.tsv")
    refineTFA(combinedSparse, data, mergedTFs;
              tfaGeneFile = tfaGeneFile,
              edgeSS = edgeSS,
              minTargets = minTargets,
              zScoreTFA = zScoreTFA, # API wrapper translates → zTarget internally
              exprFile = exprFile,
              targFile = targFile,
              regFile = regFile,
              outputDir = combDir)

    return tfaGrn
end

module InferelatorJL

# ── Dependencies ──────────────────────────────────────────────────────────────
using ArgParse, Arrow, CSV,
CategoricalArrays, Colors +using DataFrames, Distributions, FileIO, GLMNet +using InlineStrings, Interpolations, JLD2, Measures +using NamedArrays, OrderedCollections, ProgressBars +using PyPlot, Random, SparseArrays, StatsBase, TickTock + +# ── Structs (always first — all type definitions in one place) ──────────────── +include("Types.jl") # GeneExpressionData, mergedTFsResult, PriorTFAData, GrnData, BuildGrn + +# ── Functions ───────────────────────────────────────────────────────────────── +include("data/GeneExpressionData.jl") # data loading functions +include("prior/MergeDegenerateTFs.jl") # TF merging functions +include("data/PriorTFAData.jl") # prior processing + TFA functions + +# ── Core pipeline ───────────────────────────────────────────────────────────── +include("utils/DataUtils.jl") # Data transformation utilities +include("utils/NetworkIO.jl") # Save/write network files +include("utils/PartialCorrelation.jl") # Partial correlation via precision matrix + +include("grn/PrepareGRN.jl") # preparePredictorMat!, preparePenaltyMatrix!, constructSubsamples +include("grn/BuildGRN.jl") # bstarsWarmStart, bstartsEstimateInstability, chooseLambda!, rankEdges! 
+include("grn/AggregateNetworks.jl") # combineGRNs / aggregateNetworks +include("grn/RefineTFA.jl") # combineGRNS2 / refineTFA +include("grn/UtilsGRN.jl") # GRN utility helpers + +# ── Metrics ─────────────────────────────────────────────────────────────────── +include("metrics/Constants.jl") +include("metrics/MetricUtils.jl") +include("metrics/CalcPR.jl") +include("metrics/Metrics.jl") +include("metrics/plotting/PlotSingle.jl") +include("metrics/plotting/PlotBatch.jl") + +# ── Public API ──────────────────────────────────────────────────────────────── +include("API.jl") + +# ── Exports ─────────────────────────────────────────────────────────────────── +export + # Structs + GeneExpressionData, + mergedTFsResult, + PriorTFAData, + GrnData, + BuildGrn, + + # High-level API (src/API.jl) + loadData, + loadPrior, + estimateTFA, + buildNetwork, + aggregateNetworks, + refineTFA, + evaluateNetwork, + inferGRN, + + # Data utilities + convertToLong, + convertToWide, + frobeniusNormalize, + completeDF, + mergeDFs, + check_column_norms, + writeTSVWithEmptyFirstHeader, + binarizeNumeric!, + + # I/O + saveData, + writeNetworkTable! + + # ------------------------------------------------------------------------- + # Internal pipeline functions are intentionally NOT exported. + # They remain accessible via InferelatorJL. if needed. + # loadExpressionData! loadAndFilterTargetGenes! loadPotentialRegulators! + # processTFAGenes! processPriorFile! mergeDegenerateTFs + # calculateTFA! preparePredictorMat! preparePenaltyMatrix! + # constructSubsamples bstarsWarmStart bstartsEstimateInstability + # chooseLambda! rankEdges! 
    # -------------------------------------------------------------------------
    # -------------------------------------------------------------------------

end # module InferelatorJL
# =============================================================================
# Types.jl — All struct definitions for InferelatorJL
#
# Included FIRST in InferelatorJL.jl so every function file can freely
# reference any type without worrying about include order.
# =============================================================================


# ── Data loading ──────────────────────────────────────────────────────────────

# Container for all loaded/filtered expression inputs. Populated in place by
# loadExpressionData!, loadAndFilterTargetGenes!, loadPotentialRegulators!,
# and processTFAGenes!; the zero-argument constructor yields empty fields.
mutable struct GeneExpressionData
    cellLabels::Vector{String}          # sample/cell names (columns of the matrices)
    geneNames::Vector{String}           # all gene names in geneExpressionMat
    geneExpressionMat::Matrix{Float64}  # full expression matrix, genes × samples
    potRegMatmRNA::Matrix{Float64}      # mRNA expression of potential regulators
    potRegs::Vector{String}             # potential regulator (TF) names
    potRegsmRNA::Vector{String}         # regulators with measured mRNA
    targGenes::Vector{String}           # filtered target gene names
    targGeneMat::Matrix{Float64}        # target gene expression, targets × samples
    tfaGenes::Vector{String}            # genes used for TFA estimation
    tfaGeneMat::Matrix{Float64}         # expression of the TFA gene set

    function GeneExpressionData()
        return new(
            [],
            [],
            Matrix{Float64}(undef, 0, 0),
            Matrix{Float64}(undef, 0, 0),
            [],
            [],
            [],
            Matrix{Float64}(undef, 0, 0),
            [],
            Matrix{Float64}(undef, 0, 0)
        )
    end
end


# ── Prior / TF merging ────────────────────────────────────────────────────────

# Result of collapsing TFs with identical target sets (mergeDegenerateTFs).
# Both fields stay `nothing` until merging is actually run.
mutable struct mergedTFsResult
    mergedPrior::Union{DataFrame, Nothing}    # prior with degenerate TFs merged
    mergedTFMap::Union{Matrix{String}, Nothing} # map meta-TF name → original TFs

    function mergedTFsResult()
        return new(nothing, nothing)
    end
end


# ── TFA ───────────────────────────────────────────────────────────────────────

# Prior matrix plus everything needed for / produced by TFA estimation
# (processPriorFile! fills the prior fields, calculateTFA! fills medTfas).
mutable struct PriorTFAData
    pRegs::Vector{String}               # regulators present in the prior
    pTargs::Vector{String}              # targets present in the prior
    priorMatrix::Matrix{Float64}        # prior edge weights, targets × regulators
    pRegsNoTfa::Vector{String}          # regulators excluded from TFA estimation
    pTargsNoTfa::Vector{String}         # their corresponding targets
    priorMatrixNoTfa::Matrix{Float64}   # prior restricted to the no-TFA set
    noPriorRegs::Vector{String}         # regulators with no prior edges
    noPriorRegsMat::Matrix{Float64}     # expression for prior-less regulators
    targExpression::Matrix{Float64}     # target expression used in the TFA solve
    medTfas::Matrix{Float64}            # estimated TF activities, TFs × samples

    function PriorTFAData()
        return new(
            [],
            [],
            Matrix{Float64}(undef, 0, 0),
            [],
            [],
            Matrix{Float64}(undef, 0, 0),
            [],
            Matrix{Float64}(undef, 0, 0),
            Matrix{Float64}(undef, 0, 0),
            Matrix{Float64}(undef, 0, 0)
        )
    end
end


# ── GRN ───────────────────────────────────────────────────────────────────────

# Working state of the mLASSO-StARS run for one predictor mode: inputs
# (predictor/penalty/response matrices, subsamples) and per-λ outputs
# (instabilities, stability counts, betas).
mutable struct GrnData
    predictorMat::Matrix{Float64}       # predictors (TFA or TF mRNA) × samples
    penaltyMat::Matrix{Float64}         # per-edge LASSO penalties (Inf = excluded)
    allPredictors::Vector{String}       # predictor (TF) names
    subsamps::Matrix{Int64}             # subsample index sets, one row per subsample
    responseMat::Matrix{Float64}        # target gene expression × samples
    maxLambdaNet::Float64               # network-level λ upper bound (warm start)
    minLambdaNet::Float64               # network-level λ lower bound (warm start)
    minLambdas::Matrix{Float64}         # per-gene λ lower bounds
    maxLambdas::Matrix{Float64}         # per-gene λ upper bounds
    netInstabilitiesUb::Vector{Float64} # network instability upper bounds per λ
    netInstabilitiesLb::Vector{Float64} # network instability lower bounds per λ
    instabilitiesUb::Matrix{Float64}    # per-gene instability upper bounds
    instabilitiesLb::Matrix{Float64}    # per-gene instability lower bounds
    netInstabilities::Vector{Float64}   # network-level instabilities per λ
    geneInstabilities::Matrix{Float64}  # gene-level instabilities, genes × λ
    lambdaRange::Vector{Float64}        # λ grid used in the fine pass
    lambdaRangeGene::Vector{Vector{Float64}} # per-gene λ grids ("Gene" mode)
    stabilityMat::Array{Float64}        # edge selection counts, λ × genes × TFs
    priorMatProcessed::Matrix{Float64}  # processed prior aligned to predictors
    betas::Array{Float64,3}             # LASSO coefficients, genes × TFs × λ

    function GrnData()
        return new(
            Matrix{Float64}(undef, 0, 0),       # predictorMat
            Matrix{Float64}(undef, 0, 0),       # penaltyMat
            [],                                  # allPredictors
            Matrix{Int64}(undef, 0, 0),          # subsamps (correctly Int64)
            Matrix{Float64}(undef, 0, 0),        # responseMat
            0.0,                                 # maxLambdaNet
            0.0,                                 # minLambdaNet
            Matrix{Float64}(undef, 0, 0),        # minLambdas
            Matrix{Float64}(undef, 0, 0),        # maxLambdas
            [],                                  # netInstabilitiesUb
            [],                                  # netInstabilitiesLb
            Matrix{Float64}(undef, 0, 0),        # instabilitiesUb
            Matrix{Float64}(undef, 0, 0),        # instabilitiesLb
            [],                                  # netInstabilities
            Matrix{Float64}(undef, 0, 0),        # geneInstabilities
            [],                                  # lambdaRange
            Vector{Vector{Float64}}(undef, 0),   # lambdaRangeGene
            Array{Float64}(undef, 0, 0, 0),      # stabilityMat (3-D)
            Matrix{Float64}(undef, 0, 0),        # priorMatProcessed
            Array{Float64,3}(undef, 0, 0, 0)     # betas
        )
    end
end


# Final network produced by chooseLambda!/rankEdges!: the chosen λ, edge
# rankings, and the ranked edge table written by writeNetworkTable!.
mutable struct BuildGrn
    networkStability::Matrix{Float64}   # edge stabilities at the chosen λ
    lambda::Union{Float64, Vector{Float64}} # scalar ("Network") or per-gene λ
    targs::Vector{String}               # target gene per ranked edge
    regs::Vector{String}                # regulator per ranked edge
    rankings::Vector{Float64}           # edge ranks
    signedQuantile::Vector{Float64}     # signed stability quantile per edge
    partialCorrelation::Vector{Float64} # partial correlation per edge
    inPrior::Vector{String}             # prior-membership label per edge
    networkMat::Matrix{Any}             # full ranked edge table
    networkMatSubset::Matrix{Any}       # edge table after per-gene capping
    inPriorVec::Vector{Float64}         # numeric prior-membership indicator
    betas::Matrix{Float64}              # coefficients at the chosen λ
    mergeTfLocVec::Vector{Float64}      # locations of merged meta-TFs

    function BuildGrn()
        return new(
            Matrix{Float64}(undef, 0, 0),   # networkStability
            0.0,                            # lambda
            [],                             # targs
            [],                             # regs
            [],                             # rankings
            [],                             # signedQuantile
            [],                             # partialCorrelation
            [],                             # inPrior
            Matrix{Float64}(undef, 0, 0),   # networkMat
            Matrix{Float64}(undef, 0, 0),   # networkMatSubset
            [],                             # inPriorVec
            Matrix{Float64}(undef, 0, 0),   # betas
            Float64[]                       # mergeTfLocVec
        )
    end
end
using DataFrames
using CSV
using LinearAlgebra
using TickTock
using FileIO

# Convert a wide DataFrame (>3 cols) to long format via stack; a DataFrame
# that is already long (≤3 cols) is returned unchanged.
function convertToLong(data)
#dfs = [convertToLong(df) for df in dfs]
    if ncol(data) > 3
        return stack(data, Not(1) )
    else
        return data
    end
end

 """
 Converts a 3‑column long-format DataFrame to wide‑format using the unstack function.
 If the input DataFrame has exactly 3 columns, the conversion is performed.

 - Data columns is a wide matrix with columns as TF and rows as target genes or
   a long data with columns in the order TF, Gene, Weights.
   indices::Union{Nothing, NTuple{3, Int}} = nothing:
   A tuple specifying the column indices in the order (pivot, key, value).
   - pivot: The column that provides the row identifier.
   - key: The column whose unique values will become new column names.
   - value: The column from which the cell values are taken.

 If no indices are provided, the function defaults to (2, 1, 3).
 If the DataFrame has more than 3 columns, the original DataFrame is returned.
+ If the DataFrame has less than 3 columns, an error is thrown. + + # USAGE + dfs = [convertToWide(df) for df in dfs] + dfs = [convertToWide(df; indices = (1,2,3)) for df in dfs] + """ +function convertToWide(Data; indices::Union{Nothing, NTuple{3, Int}}=nothing) + ncols = ncol(Data) + + if ncols < 3 + error("DataFrame has less than 3 columns. A 3‑column DataFrame or a wide-matrix is required.") + elseif ncols > 3 + # More than 3 columns: return data unchanged. + return Data + else # Exactly 3 columns + # Use provided indices, or default to (2, 1, 3) + inds = isnothing(indices) ? (2,1,3) : indices + # Convert the specified columns to Symbols for unstack. + idSym = Symbol(names(Data)[inds[1]]) + keySym = Symbol(names(Data)[inds[2]]) + valueSym = Symbol(names(Data)[inds[3]]) + + return unstack(Data, idSym, keySym, valueSym) + end +end + + + + +""" +frobeniusNormalize(df::DataFrame; dims::Symbol = :row) + +Normalize a DataFrame based on the Frobenius (L2) norm. + +• If dims = :row (default), it normalizes each row (ignoring the first column). +• If dims = :column, it normalizes each column (ignoring the first column). + +The first column is assumed to contain non‐numeric data (e.g., row identifiers) +and is left unchanged. +""" +function frobeniusNormalize(df::DataFrame, dims::Symbol =:column) + dfNorm = deepcopy(df) + dfMat = Matrix(dfNorm[!, 2:end]) # Extract numerical part + dfMat = convert(Matrix{Float64}, dfMat) # Ensure it's Float64 + + if dims == :row + # Normalize each row. + for i in 1:size(dfMat, 1) + nrm = norm(dfMat[i, :], 2) # Compute the L2 norm for the row. + if nrm != 0 + dfMat[i, :] ./= nrm + end + end + elseif dims == :column + # Normalize each column. + for j in 1:size(dfMat, 2) + nrm = norm(dfMat[:, j], 2) # Compute the L2 norm for the column. + if nrm != 0 + dfMat[:, j] ./= nrm + end + end + else + throw(ArgumentError("dims must be :row or :column")) + end + + # Update the DataFrame with the normalized numeric values. 
+ dfNorm[!, 2:end] .= dfMat + return dfNorm +end + + +""" +completeDF(df::DataFrame, id::Symbol, idsALL, allCols) + +Aligns a DataFrame to a common set of row identifiers and columns. + +### Arguments: +- `df::DataFrame`: The input DataFrame to align. +- `id::Symbol`: The identifier column (e.g., gene or sample ID). +- `idsALL`: A vector of all unique row identifiers across multiple DataFrames. +- `allCols`: A vector of all unique column names across multiple DataFrames. + +### Returns: +- A new DataFrame with: +- Rows corresponding to `idsALL` (missing rows filled with `missing`). +- Columns matching `allCols` (missing columns filled with `missing`). +- The original values retained where available. + +""" +function completeDF(df::DataFrame, id::Symbol, idsALL, allCols) + dfNew = DataFrame() + dfNew[!, id] = idsALL # Ensure all IDs are included + + for col in allCols + if col in Symbol.(names(df)) + mapping = Dict(row[id] => row[col] for row in eachrow(df)) # Store existing values + dfNew[!, col] = [get(mapping, rid, missing) for rid in idsALL] # Align data + else + dfNew[!, col] = fill(missing, length(idsALL)) # Fill missing columns + end + end + return dfNew +end + +""" +"Merges a vector of DataFrames by summing or averaging them cell-wise. + +Arguments: +- dfs::Vector{DataFrame}: A vector of DataFrames to merge. +- id::Symbol: The identifier column (common to every DataFrame). +- option::String: The operation to perform: either "sum" or "avg". + +Returns: +A DataFrame where the first column is the identifier and the remaining columns are the cell-wise +sum or average (depending on the option) of the numeric values from all data frames (with columns sorted in alphabetical order). +""" +function mergeDFs(dfs::Vector{DataFrame}, id::Symbol = :Gene, option::String = "sum") + + # 1. Compute the full set of row identifiers + idsALL = reduce(union, [unique(df[!, id]) for df in dfs]) + sort!(idsALL) + + # 2. 
 Compute the full set of numeric columns
    allCols = reduce(union, [setdiff(names(df), [string(id)]) for df in dfs])
    allCols = sort(Symbol.(allCols))

    # 3. Complete all DataFrames to have the same structure
    completed_dfs = [completeDF(df, id, idsALL, allCols) for df in dfs]

    # 4. Initialize matrices for sum and count tracking
    nRows, nCols = length(idsALL), length(allCols)
    sumMat = zeros(nRows, nCols)
    countMat = zeros(Int, nRows, nCols)

    # 5. Compute sum and count matrices
    for df in completed_dfs
        mat = Matrix(df[!, allCols])
        mat = coalesce.(mat, 0.0) # Treat missing cells as 0 so they don't contribute.
        sumMat .+= mat
        countMat .+= (mat .!= 0) # Count nonzero interactions
    end

    # 6. Compute sum or average
    # "avg" divides by the number of nonzero contributions (clamped to 1 to avoid /0).
    resultMat = option == "avg" ? sumMat ./ max.(countMat, 1) : sumMat

    # 7. Create the merged DataFrame
    merged = DataFrame(id => idsALL)
    for (j, col) in enumerate(allCols)
        merged[!, col] = resultMat[:, j]
    end

    return merged
end

"""
Checks if each numeric column (except the first) is L2 normalized (norm approx 1).
Returns a Dict of column name => Bool, plus the counts of true and false cases.
Non-numeric columns are skipped (with a message).
"""
function check_column_norms(df::DataFrame; atol=1e-8)
    details = Dict{String,Bool}()
    count_true = 0
    count_false = 0

    for col in names(df)[2:end]
        # Check if the column is numeric by testing its element type.
        if !(eltype(df[!, col]) <: Number)
            println("Column: ", col, " is non-numeric, skipping...")
            continue
        end
        # Convert the column values to Float64.
        values = Float64.(df[!, col])
        # Compute the L₂ norm.
        col_norm = norm(values, 2)
        # Check whether the norm is approximately 1.
        is_norm_one = isapprox(col_norm, 1.0; atol=atol)

        details[string(col)] = is_norm_one
        if is_norm_one
            count_true += 1
        else
            count_false += 1
        end
        # println("Column: ", col, " Norm: ", col_norm, " ~ 1?
", is_norm_one) + end + + println("Total numeric columns normalized (True): ", count_true) + println("Total numeric columns not normalized (False): ", count_false) + return details, count_true, count_false +end + +function writeTSVWithEmptyFirstHeader(df::DataFrame, filepath::String; delim::Char='\t') + open(filepath, "w") do io + # Create custom header: + # Make the first header cell an empty string. + # The remaining header cells are the remaining column names. + hdr = [""; names(df)[2:end]] + # Write the header row using the delimiter. + println(io, join(hdr, delim)) + + # Write each row. Each row is converted into strings. + for row in eachrow(df) + # Convert every element of the row to a string. + row_values = [string(x) for x in collect(row)] + println(io, join(row_values, delim)) + end + end +end + +function binarizeNumeric!(data) + if isa(data, DataFrame) + for col in names(data) + if eltype(data[!, col]) <: Number + data[!, col] .= Int.(data[!, col] .!= 0) + end + end + return data + elseif isa(data, AbstractMatrix) + data .= Int.(data .!= 0) + return data + else + throw(ArgumentError("Input must be DataFrame or matrix")) + end +end + + +# function binarizeNonZero!(data) +# if isa(data, DataFrame) +# for c in names(data) +# if eltype(data[!, c]) <: Number +# data[!, c] .= ifelse.(data[!, c] .!= 0, 1, 0) +# end +# end +# elseif isa(data, AbstractMatrix) +# data .= ifelse.(data .!= 0, 1, 0) +# else +# throw(ArgumentError("Input must be a DataFrame or a Matrix")) +# end +# return data +# end + + +#= +# Test workflow with small data + +Example DataFrames with row names (represented as the first column in the DataFrame) +df1 = DataFrame(RowName = ["W", "X"], A = [1, 2], B = [3, 4], C = [5, 6]) +df2 = DataFrame(Gene = ["Y", "X"], B = [7, 8], C = [9, 10], D = [11, 12]) +df3 = DataFrame(Blue = ["P", "Q"], A = [13, 14], D = [15, 16], E = [17,18]) + +# # # # List of all DataFrames you want to combine +dfs = [df1, df2, df3] +dfs = [convertToLong(df) for df in dfs] 
+dfs = [convertToWide(df; indices = (1,2,3)) for df in dfs] + +tick() +dfNorm = [frobeniusNormalize1(df, :row) for df in dfs] +tock() + +commonID = :Gene +dfNorm = [rename!(df, names(df)[1] => commonID) for df in dfNorm] + +tick() +mergeDFs(dfNorm, :Gene, "sum") +tock() +=# + + +#Compute Frobenius norm for each row (excluding the first column) +# norm_dfs = [DataFrame(A = df[!, 1], FrobeniusNorm = [norm(row[2:end], 2) for row in eachrow(df)]) for df in dfs] # row +# norm_dfs = [DataFrame(Column = names(df)[2:end], FrobeniusNorm = [ norm(col, 2) for col in eachcol(df[!, 2:end]) ]) for df in dfs] # col + + +# df1 = "/data/miraldiNB/Michael/projects/GRN/hCD4T_Katko/dataBank/Priors/SCENICp/majorCellType/tf2gene_only_TF2G_wide_binary.tsv" +# df1 = CSV.read(df1, DataFrame; delim = "\t") +# df1Norm = frobeniusNormalize(df1, :column) +# check_column_norms(df1Norm; atol=1e-3) + +# df2 = "/data/miraldiNB/Michael/projects/GRN/hCD4T_Katko/dataBank/Priors/MotifScan5kbTSS_b.tsv" +# df2 = CSV.read(df2, DataFrame; delim = "\t") +# check_column_norms(df2; atol=1e-8) + +# dfs = [df1, df2] +# dfs = [CSV.read(df1, DataFrame; delim = "\t") for df in dfpath] +# merged = mergeDFs(dfs, :Column1, "sum") +# mergedNorm = frobeniusNormalize(merged, :column) +# check_column_norms(mergedNorm; atol=1e-8) +# maximum(Matrix(merged_b[:, 2:end])) + +# merged_b = binarizeNonZero!(merged) +# # order = sortperm(merged_b[:, "Column1"], rev = false) +# # merged_b = merged_b[order, :] +# writeTSVWithEmptyFirstHeader(merged_b, "/data/miraldiNB/Michael/hCD4T_Katko/dataBank/Priors/tf2gene_5kbTSS_b.tsv"; delim='\t') + + + diff --git a/src/Utils/networkIO.jl b/src/Utils/networkIO.jl new file mode 100755 index 0000000..8c6d6cb --- /dev/null +++ b/src/Utils/networkIO.jl @@ -0,0 +1,28 @@ +using JLD2 +using DelimitedFiles +using Printf +using Dates + +# Save all core structs +function saveData(expressionData, tfaData, grnData, buildGrn, outputDir::String, fileName::String) + output = joinpath(outputDir, fileName) 
+ @save output expressionData tfaData grnData buildGrn +end + +# Write network tables +function writeNetworkTable!(buildGrn; outputDir::String, networkName::Union{String, Nothing}=nothing) + baseName = (networkName === nothing || isempty(networkName)) ? "edges" : networkName * "edges" + + outputFile = joinpath(outputDir, baseName * ".tsv") + colNames = "TF\tGene\tsignedQuantile\tStability\tCorrelation\tinPrior\n" + open(outputFile, "w") do io + write(io, colNames) + writedlm(io, buildGrn.networkMat) + end + + outputFileSubset = joinpath(outputDir, baseName * "_subset.tsv") + open(outputFileSubset, "w") do io + write(io, colNames) + writedlm(io, buildGrn.networkMatSubset) + end +end diff --git a/src/Utils/partialCorrelation.jl b/src/Utils/partialCorrelation.jl new file mode 100755 index 0000000..be8c2a6 --- /dev/null +++ b/src/Utils/partialCorrelation.jl @@ -0,0 +1,96 @@ + +# Method 1: Matrix Inversion Method: +# this methods computes partial correlation from precision-matrix (the inverse of variance-covariance matrix) +using LinearAlgebra +function partialCorrelationMat(X::Matrix{Float64}; epsilon = 1e-7, first_vs_all = false) + # Mean centering of the columns (mean subtraction) + X_centered = X .- mean(X, dims=1) + + # Compute the covariance matrix + sigma = cov(X_centered) + # regularizing the covariance matrix to avoid ill-conditining and singularity + sigma = sigma + epsilon * I + + # Precision matrix (inverse of the covariance matrix) + theta = inv(sigma) + p = size(X, 2) # number of features/predictors/explanatory variables + + if first_vs_all + # Compute partial correlation for the first variable vs all others + P = ones(1, p) # Initialize the partial correlation matrix + for j in 2:p + P[j] = -theta[1, j] / sqrt(theta[1, 1] * theta[j, j]) + end + else + # Compute the full partial correlation matrix + P = ones(p, p) # Initialize the partial correlation matrix + for i in 1:(p-1) + for j in (i+1):p + P[i , j] = -theta[i, j] / sqrt(theta[i, i] * theta[j, j]) + 
P[j, i] = P[i, j] # Symmetry + end + end + end + + return P +end + + +# Method 2: Using Regression +using GLM, Statistics, DataFrames + +function partialCorrReg(X::Matrix{Float64}; first_vs_all = false) + p = size(X, 2) + P = ones(p, p) + + # Convert the matrix to a DataFrame for easier variable handling + df = DataFrame(X, :auto); + col_names = names(df) + + if first_vs_all + P = ones(1, p) + for j in 2:p + keep_indices = setdiff(2:p, j) + covariates = col_names[keep_indices] # All columns except i and j + + # First, regress the first covariate on all other features except + model_1 = lm(term.(col_names[1]) ~ sum(term.(covariates)), df) + res_1 = residuals(model_1) # Get residuals for 1 + + # Then, regress j on all others except i + model_j = lm(term.(col_names[j]) ~ sum(term.(covariates)), df) + res_j = residuals(model_j) # Get residuals for j + + # Compute c1rrelation between the residuals + r = cor(res_1, res_j) + P[j] = r + end + else + P = ones(p, p) + for i in 1:p + for j in (i+1):p + + keep_indices = setdiff(1:p, [i ,j]) + covariates = col_names[keep_indices] # All columns except i and j + + # reference: https://discourse.julialang.org/t/using-all-independent-variables-with-formula-in-a-multiple-linear-model/43691/4 + + # First, regress i on all other features except j (i is the response) + model_i = lm(term.(col_names[i]) ~ sum(term.(covariates)), df) + res_i = residuals(model_i) # Get residuals for i + + # Then, regress j on all others except i + model_j = lm(term.(col_names[j]) ~ sum(term.(covariates)), df) + res_j = residuals(model_j) # Get residuals for j + + # Compute correlation between the residuals + r = cor(res_i, res_j) + P[i, j] = r + P[j, i] = r # Symmetric matrix + end + end + end + return P +end + + diff --git a/src/data/GeneExpressionData.jl b/src/data/GeneExpressionData.jl new file mode 100755 index 0000000..01ef5dd --- /dev/null +++ b/src/data/GeneExpressionData.jl @@ -0,0 +1,205 @@ +# 00_Data/geneExpression.jl + + using DelimitedFiles + 
 using Statistics
 using JLD2
 using CSV
 using Arrow
 using DataFrames

 # Struct defined in src/Types.jl


 # Load Expression Data
 """
 loadExpressionData!(data::GeneExpressionData, geneExprFile)

 Loads gene expression data from a specified file and updates the `GeneExpressionData` object.

 # Arguments
 - `data::GeneExpressionData`: The data object to be populated with gene expression details.
 - `geneExprFile::String`: The path to the gene expression data file, which can be in `.arrow` format or tab-delimited text format.

 # Updates
 - `data.cellLabels`: A vector of sample condition names, derived from the column headers (excluding the first).
 - `data.geneNames`: A vector of gene names, derived from the first column.
 - `data.geneExpressionMat`: A matrix of expression values, converted to `Float64`.

 # Raises
 - An error if the gene expression file does not exist or if its path is invalid.

 # Notes
 - Assumes that the first column in the file contains gene names and subsequent columns contain cell labels and expression data.
 - For text files, data is sorted by gene names; Arrow files are loaded in their stored row order.
+ """ + function loadExpressionData!(data::GeneExpressionData, geneExprFile) + if isfile(geneExprFile) + if endswith(geneExprFile, ".arrow") + dfArrow = Arrow.Table(geneExprFile) + df = deepcopy(DataFrame(dfArrow)) + dfArrow = nothing; + cellLabels = names(df)[2:end] # Assume first column is gene names + geneNames = df[:, 1] # Extract gene names + ncounts = Matrix(df[:, 2:end]) + df = nothing + else + fid = open(geneExprFile); + C = readdlm(fid, '\t', '\n'); + close(fid) + cellLabels = C[1, :]; + cellLabels = filter(!isempty, cellLabels); + C = C[2:end, :]; + inds = sortperm(C[:, 1]); + C = C[inds, :]; + geneNames = C[:, 1]; + ncounts = C[:, 2:end] + end + ncounts = convert(Matrix{Float64}, ncounts); + + data.cellLabels = String.(cellLabels); + data.geneNames = String.(geneNames); + data.geneExpressionMat = ncounts; + ncounts = nothing; + else + error("Expression data file path is invalid.") + end + end + + + """ + loadAndFilterTargetGenes!(data::GeneExpressionData, targetGeneFile; eps=1E-10) + + Loads target genes from a file, filters them based on presence and variance in the existing gene expression data, and updates the `GeneExpressionData` object. + + # Arguments + - `data::GeneExpressionData`: The data object to be updated with filtered target gene information. + - `targetGeneFile::String`: The path to the file containing target gene names. + - `eps::Float64`: A keyword argument specifying the variance threshold for filtering target genes. Default is `1E-10`. + + # Updates + - `data.targGenes`: A filtered vector of target gene names that are present in the gene expression data and meet the variance threshold. + - `data.targGeneMat`: A matrix of expression values for these filtered target genes. + + # Raises + - An error if the target gene file does not exist. + - An error if no target genes are found in the gene expression data. + - An error if all target genes are filtered out due to low variance. 
+ + # Notes + - This function finds matching genes and filters them based on a variance threshold to ensure sufficient expression variability across samples. + """ + function loadAndFilterTargetGenes!(data::GeneExpressionData, targetGeneFile; epsilon=1E-10) + if isfile(targetGeneFile) + # Load in target gene file + fid = open(targetGeneFile) + targetGenes = readdlm(fid, String) + close(fid) + + # Find all geneNames that are in the targer gene file + inds = findall(in(targetGenes), data.geneNames) + if isempty(inds) + error("No target genes found in expression data!") + end + + # Filter target genes and expression matrix + targGenes = data.geneNames[inds] + targGeneMat = data.geneExpressionMat[inds, :] + + # Filter genes by minimum variance cutoff + stds = std(targGeneMat, dims=2) + keep = [index[1] for index in findall(stds .>= epsilon)] + targGenesFilter = targGenes[keep] + targGeneMatFilter = targGeneMat[keep, :] + if isempty(keep) + error("All target genes removed due to low variance!") + else + println(length(targGenesFilter), " target genes retained after filtering") + end + data.targGenes = targGenesFilter + data.targGeneMat = targGeneMatFilter + else + error("Target gene file not found.") + end + end + + # Load Regulators + """ + loadPotentialRegulators!(data::GeneExpressionData, potRegFile) + + Loads potential regulators and updates the `GeneExpressionData` object with their expression data. + + # Arguments + - `data::GeneExpressionData`: The data object to be populated with potential regulator information. + - `potRegFile::String`: The path to the file containing potential regulator names. + + # Updates + - `data.potRegs`: List of all potential regulator names. + - `data.potRegsmRNA`: List of regulator names with corresponding mRNA expression data. + - `data.potRegMatmRNA`: Expression matrix for regulators with mRNA data. + + # Raises + - An error if the potential regulators file does not exist. 
+ + # Notes + - Only the potential regulators present in the existing gene expression data are included. + """ + function loadPotentialRegulators!(data::GeneExpressionData, potRegFile) + if isfile(potRegFile) + fid = open(potRegFile) + potentialRegs = readdlm(fid, String) + close(fid) + potentialRegs = vec(potentialRegs) + + indsPotRegs = findall(in(data.geneNames), potentialRegs) + potentialRegs = potentialRegs[indsPotRegs] + + inds = findall(in(potentialRegs), data.geneNames) + potRegsmRNA = data.geneNames[inds] + potRegMatmRNA = data.geneExpressionMat[inds, :] + + data.potRegs = potentialRegs + data.potRegsmRNA = potRegsmRNA + data.potRegMatmRNA = potRegMatmRNA + else + error("Potential regulators file not found.") + end + end + + + + # Process genes for TFA calculation + """ + processTFAGenes(file, geneSc, nCounts) + + Processes TFA genes by retrieving their expression data from the expression matrix. + + # Arguments + - `file::String`: The path to the TFA genes file. If the path is empty or invalid, all genes are used. + - `geneNames::Vector{String}`: A vector of gene names from the expression data. + - `geneExpressionMat::Matrix{Float64}`: A matrix containing expression values for the genes. + + # Returns + - `tfaGenes::Vector{String}`: A vector of TFA gene names with expression data. + - `tfaGeneMat::Matrix{Float64}`: A matrix of expression values for the TFA genes. + + # Notes + - If the specified file does not exist, all genes are considered for TFA. 
+ """ + function processTFAGenes!(data::GeneExpressionData, tfaGeneFile::Union{String, Nothing}; outputDir::Union{String, Nothing}=nothing) + if (tfaGeneFile !== nothing) && (tfaGeneFile != "") && isfile(tfaGeneFile) + tfaGenes = readlines(tfaGeneFile) + else + tfaGenes = data.geneNames + end + + inds = findall(in(tfaGenes), data.geneNames) + data.tfaGenes = data.geneNames[inds] + data.tfaGeneMat = data.geneExpressionMat[inds, :] + + if outputDir !== nothing && outputDir !== "" + + outputFile = joinpath(outputDir, "geneExprMat.jld") + save_object(outputFile, data) + + end + end diff --git a/src/data/PriorTFAData.jl b/src/data/PriorTFAData.jl new file mode 100755 index 0000000..51f46c5 --- /dev/null +++ b/src/data/PriorTFAData.jl @@ -0,0 +1,183 @@ + # Standard libraries + using LinearAlgebra + using Statistics + using DelimitedFiles + using JLD2 + + # Struct defined in src/Types.jl + + """ + processPriorFile!(priorData::PriorTFAData, priorFile) + + Reads and processes a prior file, storing the extracted information in a `PriorTFAData` object. + + # Arguments + - `priorData::PriorTFAData`: The struct to be populated with data from the prior file. + - `priorFile::String`: The path to the prior file, formatted with tab-separated values. + + # Updates + - `priorData.pRegs`: Stores the list of transcription factors (TFs) from the prior file. + - `priorData.pTargs`: Stores the list of target genes from the prior file. + - `priorData.priorMatrix`: Stores the interactions matrix, indicating TF-gene relationships. + + # Raises + - An error if the prior file does not exist. + + # Notes + - Assumes the first row in the file contains TF names, and subsequent rows represent interactions with target genes. + - The gene list and matrix are sorted alphabetically by gene names. 
+ """ + function processPriorFile!(priorData::PriorTFAData, + expressionData::GeneExpressionData, + priorFile; mergedTFsData::Union{mergedTFsResult, Nothing}=nothing, minTargets = 3) + + if isfile(priorFile) + + println("--- Case 1") + fid = open(priorFile) + C = readdlm(fid, '\t', '\n', skipstart=0) + close(fid) + + # Process and store genes and interactions matrix + pRegsTmp = convert(Vector{String}, filter(!isempty, C[1, :])) + C = C[2:end, :] + inds = sortperm(C[:, 1]) + C = C[inds, :] + pTargsTmp = convert(Vector{String}, C[:, 1]) + pMatrixTmp = convert(Matrix{Float64}, C[:, 2:end]) + + # Filter to only include those in potential regulator and target gene list + targInds = findall(in(expressionData.tfaGenes), pTargsTmp) + regInds = findall(in(expressionData.potRegs), pRegsTmp) + pRegsNoTfa = pRegsTmp[regInds] + pTargsNoTfa = pTargsTmp[targInds] + priorMatrixNoTfa = pMatrixTmp[targInds, regInds] + + # Find TFs that have expression data but arent in the prior + noPriorRegs = setdiff(expressionData.potRegsmRNA, pRegsNoTfa) + expInds = findall(in(noPriorRegs), expressionData.potRegsmRNA) + noPriorRegsMat = expressionData.potRegMatmRNA[expInds,:] + + println("--- Case 2") + ## Case 2: prior-based TFA is used + # Check whether there were degenerate TFs and outputs + if (mergedTFsData !== nothing) && + (mergedTFsData.mergedPrior !== nothing) && + (mergedTFsData.mergedTFMap !== nothing) + + println("-------- Using merge degenerate TFs prior file") + mergedTFs = mergedTFsData.mergedTFMap[:, 1] + individualTFs = mergedTFsData.mergedTFMap[:, 2] + + totMergedSets = length(individualTFs) + keepMergedIndices = Int[] + + for idx in 1:totMergedSets + currSet = split(individualTFs[idx], ", ") + usedTfs = intersect(currSet, expressionData.potRegs) + if !isempty(usedTfs) + push!(keepMergedIndices, idx) + end + end + + if !isempty(keepMergedIndices) # add merged potential regulators to our list + expressionData.potRegs = union(expressionData.potRegs, mergedTFs[keepMergedIndices]) 
+ # Now load the merged prior matrix data (mergedPrior) + priorDF = mergedTFsData.mergedPrior + rowInd = sortperm(priorDF[:, 1]) + pRegsTmp = names(priorDF)[2:end] + totPRegs = length(pRegsTmp) + pTargsTmp = priorDF[rowInd, 1] + pMatrixTmp = Matrix(priorDF[rowInd, 2:end]) + end + end + + ### Filter TFs and Genes for TFA calculation + pTargInds = findall(in(expressionData.tfaGenes), pTargsTmp) + pRegInds = findall(in(expressionData.potRegs), pRegsTmp) + pTargs = pTargsTmp[pTargInds] + pRegs = pRegsTmp[pRegInds] + pInts = pMatrixTmp[pTargInds,pRegInds] + # Filter for minimum target genes in prior + interactionsPerTF = sum((abs.(sign.(pInts))), dims=1) + keepRegs = Tuple.(findall(x -> x > minTargets, interactionsPerTF)) + keepRegs = last.(keepRegs) + pInts = pInts[:,keepRegs] + pRegs = pRegs[keepRegs] + # Filter genes with no regulators + interactionsPerTarg = sum(abs.(sign.(pInts)), dims = 2) + keepTargs = Tuple.(findall(x -> x > 0, interactionsPerTarg)) + keepTargs = first.(keepTargs) + pInts = pInts[keepTargs,:] + pTargs = pTargs[keepTargs] + + ### Ensure expression and prior targets are in the same order + expTargInds = findall(in(pTargs), expressionData.tfaGenes) + targExp = expressionData.tfaGeneMat[expTargInds,:] + if !(pTargs == expressionData.tfaGenes[expTargInds]) + println("Warnings, gene order in expression matrix and prior matrix do not match!!") + end + + # Add data to object + priorData.pRegs = pRegs + priorData.pTargs = pTargs + priorData.priorMatrix = pInts + priorData.pRegsNoTfa = pRegsNoTfa + priorData.pTargsNoTfa = pTargsNoTfa + priorData.priorMatrixNoTfa = priorMatrixNoTfa + priorData.noPriorRegs = noPriorRegs + priorData.noPriorRegsMat = noPriorRegsMat + priorData.targExpression = targExp + + else + error("Prior file not found.") + end + end + + function calculateTFA!(priorData::PriorTFAData, expressionData::GeneExpressionData; + edgeSS = 0, zTarget::Bool = false, outputDir::Union{String, Nothing}=nothing) + priorMatrix = priorData.priorMatrix + 
targExp = priorData.targExpression + totTargs = size(priorMatrix, 1) + totPreds = size(priorMatrix, 2) + totConds = size(targExp, 2) + + if zTarget + targExp = (targExp .- mean(targExp, dims=2)) ./ std(targExp, dims=2) + println("Target expression normalized (z-score per gene).") + end + + if edgeSS > 0 + tfas = zeros(Float64, edgeSS, totPreds, totConds) + + for ss in 1:edgeSS + sPrior = zeros(Float64, totTargs, totPreds) + + for col in 1:totPreds + currTargs = priorMatrix[:, col] + targInds = findall(!iszero, currTargs) + totCurrTargs = length(targInds) + ssampleSize = Int(ceil(0.63 * totCurrTargs)) + ssample = rand(targInds, ssampleSize) + + sPrior[ssample, col] = priorMatrix[ssample, col] + end + tfas[ss, :, :] = sPrior \ targExp + end + + priorData.medTfas = median(tfas, dims=1) + println("Median from ", string(edgeSS), " subsamples used for prior-based TFA.") + else + # solves for X = Prior * TFA. + # TFA = argmin||priorMatrix * TFA - targExp||² + priorData.medTfas = priorMatrix \ targExp + println("No subsampling for prior-based TFA estimate.") + end + + if outputDir !== nothing && outputDir !== "" + + outputFile = joinpath(outputDir, "tfaMat.jld") + save_object(outputFile, priorData) + + end + end diff --git a/src/grn/AggregateNetworks.jl b/src/grn/AggregateNetworks.jl new file mode 100755 index 0000000..38f4b5c --- /dev/null +++ b/src/grn/AggregateNetworks.jl @@ -0,0 +1,301 @@ +""" + GRN + +This module provides functionality for building and combining Gene Regulatory Networks (GRNs). + +The GRN files are assumed to have the following columns: + TF, Gene, signedQuantile, Stability, Correlation, strokeVals, strokeWidth, inPrior +Main function: +- `combineGRNs` (internal): combine multiple GRN files with options for mean/max aggregation, + controlling edges per gene, and saving the combined network. + +Dependencies: +- Uses `PriorUtils` for prior-related helper functions. +- Includes internal utilities in `utilsGRN.jl`. 
+""" + +# using ..DataUtils +# using CSV, DataFrames, Statistics, StatsBase, Printf, Dates +# using ArgParse + +# export combineGRNs +# ──────────────────────────────────────────────────────────────── +# Helper: deterministic tie-breaker +# ──────────────────────────────────────────────────────────────── +function pickRow(g, primary::Union{String,Symbol}; order::Symbol = :max) + + """ + pickRow(g, primary; order = :max) + + Return a single row index from SubDataFrame `g` according to a deterministic + tie-breaking hierarchy. + + Arguments + ───────── + g - SubDataFrame produced by `groupby` + primary - "Stability"/:Stability or "Quantile"/:Quantile + order - :max (default) or :min (optimise up or down) + + Tie-breaking + ──────────── + If primary = Stability : Stability → Quantile → |Correlation| + If primary = Quantile : Quantile → |Correlation| + """ + + prim = Symbol(primary) # normalise to Symbol + + # 1) extreme value of the primary column + pVals = g[!, prim] + extreme = order === :max ? maximum(pVals) : minimum(pVals) + idxs = findall(==(extreme), pVals) + length(idxs) == 1 && return idxs[1] + + # 2) additional tie-breakers + if prim == :Stability && "signedQuantile" in names(g) + qVals = g.signedQuantile[idxs] + extQ = order === :max ? 
maximum(qVals) : minimum(qVals) + idxs = idxs[findall(==(extQ), qVals)] + length(idxs) == 1 && return idxs[1] + # if still tied we fall through to correlation + end + + if "Correlation" in names(g) + absCorr = abs.(g.Correlation[idxs]) + bestC = maximum(absCorr) # always take largest |ρ| + idxs = idxs[findall(==(bestC), absCorr)] + end + + return idxs[1] # deterministic fallback +end + + +# ──────────────────────────────────────────────────────────────── +# Main function +# ──────────────────────────────────────────────────────────────── +function aggregateNetworks(nets2combine::Vector{String}; + method::Union{String,Symbol} = :max, + meanEdgesPerGene::Int = 20, + useMeanEdgesPerGene::Bool = true, + outputDir::Union{String,Nothing} = nothing, + saveName::String = "") + + """ + # Arguments + - `nets2combine::Vector{String}`: A vector of GRN file names to combine. + - `method::Union{String,Symbol}`: Combination strategy — :max, :min, :mean, :meanQuantile, :maxQuantile. + - `meanEdgesPerGene::Int`: Mean number of edges per target gene. + - `outputDir::Union{String,Nothing}`: Directory to save results. Defaults to `nothing`. + - `saveName::String`: Optional filename prefix for saved files. Defaults to `""`. + useMeanEdgesPerGene: If `true`, selects `meanEdgesPerGene * uniqueGenes` edges globally. + If `false`, selects top `meanEdgesPerGene` per target gene. + """ + combineOpt = string(method) # normalise Symbol/String → String for comparisons + useMeanEdgesPerGeneMode = useMeanEdgesPerGene + saveDir = outputDir + allDfs = DataFrame[] + + # ──── 1. 
Read each file and then combine; create a rank column if not present ─────────────────────────── + for file in nets2combine + # Adjust CSV reading as needed (TSV expected if tab sep) + df = CSV.read(file, DataFrame; delim='\t') + # println(size(df)) + + # # Ensure the required columns exist + # requiredCols = ["TF", "Gene", "signedQuantile", "Stability", "Correlation", "inPrior"] + # dfCols = names(df) + # for col in requiredCols + # if !(col in dfCols) + # error("Column $(col) not found in file $(file)") + # end + # end + # df = df[!, requiredCols] + + # ensure numeric columns are Float64 + for col in [:Stability, :Correlation] + if !(eltype(df[!, col]) <: AbstractFloat) + df[!, col] = parse.(Float64, string.(df[!, col])) + end + end + + # fill in ranks ONLY if these columns do not already exist + if !("signedQuantile" in names(df)) #|| !(:Ranks in names(df)) + st_ecdf = ecdf(df.Stability) + quantile = st_ecdf.(df.Stability) + df.signedQuantiles = quantile .* sign.(df.Correlation) + end + push!(allDfs, df) + end + + # ──── 2. CONCATENATE AND AGGREGATE ACCORDING TO combineOpt ─────────────────────────── + # Combine all dataframes vertically + combinedDf = vcat(allDfs...) + # Group by TF and Gene and aggregate according to the chosen method. + groupedDf = groupby(combinedDf, ["TF", "Gene"]) + aggRows = NamedTuple[] + + for g in groupedDf + # All rows in this group share the same TF, Gene and inPrior + tf, gene, dash = g.TF[1], g.Gene[1], g.inPrior[1] + # tf, gene, dash = g.TF[1], g.Gene[1], g.strokeDashArray[1] + + if combineOpt == "mean" + c = mean(g.Correlation) + push!(aggRows, (TF = tf, Gene = gene, Stability = mean(g.Stability), + Correlation = c, + inPrior = dash)) + + elseif combineOpt == "max" + # idx = combineOpt == "max" ? 
argmax(g.Stability) : argmin(g.Stability) + idx = pickRow(g, "Stability", order = :max) + row = g[idx, :] + push!(aggRows, (TF = tf, Gene = gene, signedQuantile = row.signedQuantile, Stability = row.Stability, + Correlation = row.Correlation, inPrior = row.inPrior)) + + elseif combineOpt == "min" + # idx = combineOpt == "max" ? argmax(g.Stability) : argmin(g.Stability) + idx = pickRow(g, "Stability", order = :min) + row = g[idx, :] + push!(aggRows, (TF = tf, Gene = gene, Stability = row.Stability, + Correlation = row.Correlation, inPrior = row.inPrior)) + + elseif combineOpt == "meanQuantile" + c = mean(g.Correlation) + push!(aggRows, (TF = tf, Gene = gene, signedQuantile = mean(g.signedQuantile), Stability = mean(g.Stability), + Correlation = c, + inPrior = dash)) + + elseif combineOpt == "maxQuantile" + # idx = argmax(g.signedQuantile) + idx = pickRow(g, "signedQuantile", order = :max) + row = g[idx, :] + push!(aggRows, (TF = tf, Gene = gene, signedQuantile = row.signedQuantile, Stability = row.Stability, Correlation = row.Correlation, + inPrior = row.inPrior)) + else + error("Invalid combineOpt: $(combineOpt).") + end + end + + aggregatedDf = DataFrame(aggRows) + + + # ────── 3. Filter out extrememly weak correlations and compute new signed quantiles using broadcasting ────────────────────── + pcut = 0.01 # Correlation cutoff + aggregatedDf = filter(:Correlation => x -> abs(x) > pcut, aggregatedDf) + # ── 3b. 
Compute signedQuantile ───────────────── + if combineOpt in ("max","min","mean") + stability = aggregatedDf.Stability # Vector + F = ecdf(stability) # F(x) = proportion ≤ x + signed_q = F.(stability) .* sign.(aggregatedDf.Correlation) + aggregatedDf[!,:signedQuantile] = signed_q + end + + # Checks: + # abs_q = abs.(aggregatedDf.signedQuantile) + # i_min = argmin(abs_q) + # row_min = aggregatedDf[i_min, :] + + # sort for reproducibility + sort!(aggregatedDf, :signedQuantile, by= abs, rev = true) + # sort!(aggregatedDf, combineOpt in ("meanQuantile", "maxQuantile") ? :signedQuantile : :Stability, + # by = combineOpt in ("meanQuantile", "maxQuantile") ? abs : identity, rev = true) + + # ────── 5. choose top meanEdgesPerGene * nGenes edges ────────────────────── + + if useMeanEdgesPerGeneMode + println("Selecting the top (meanEdgesPerGene * unique targets) edges") + # Select the top `topEdgeCount = meanEdgesPerGene * uniqueGenes` & compute quantiles for all rankings + uniqueGenes = length(unique(aggregatedDf.Gene)) # current number of unique targets + topEdgeCount = round(Int, meanEdgesPerGene * uniqueGenes) # Compute the number of edges to select + selectionIndices = 1:min(topEdgeCount, nrow(aggregatedDf)) + else + println("Selecting the top meanEdgesPerGene edges per target gene") + selectionIndices = firstNByGroup(aggregatedDf.Gene, meanEdgesPerGene) + end + selectedDf = aggregatedDf[selectionIndices, :] + println("Selected $(length(selectionIndices)) edges using meanEdgesPerGene = $meanEdgesPerGene.") + + # # ────── 6. 
Generate strokeWidth and colors + # # ────── Color calculations for JP-Gene-Viz + # minRank, maxRank = extrema(selectedDf.Stability) + # rankRange = max(maxRank - minRank, eps()) # prevent division by zero + # strokeWidth = 1 .+ (selectedDf.Stability .- minRank) ./ rankRange + # # Color mapping + # medRed = [228, 26, 28] + # lightGrey = [217, 217, 217] + # strokeVals = map(abs.(selectedDf.Correlation)) do corr + # color = corr * medRed .+ (1 - corr) * lightGrey + # "rgb($(floor(Int, round(color[1]))),$(floor(Int, round(color[2]))),$(floor(Int, round(color[3]))))" + # # "rgb(" * string(floor(Int, round(color[1]))) * "," * string(floor(Int, round(color[2]))) * "," * string(floor(Int, round(color[3]))) * ")" + # end + # selectedDf[!, :strokeVals] = strokeVals + # selectedDf[!, :strokeWidth] = strokeWidth + + + # ── 7. save to disk if saveDir is provided ────────────────────────────────────────── + if saveDir !== nothing + mkpath(saveDir) + tag = isnothing(saveName) || isempty(saveName) ? combineOpt : "$(saveName)_$(combineOpt)" + # CSV.write(joinpath(saveDir, "$(tag)_aggregated.tsv"), aggregatedDf; delim = '\t') + CSV.write(joinpath(saveDir, "combined_$(tag).tsv"), selectedDf; delim = '\t') + + wideDf = convertToWide(selectedDf[ :,["TF","Gene","signedQuantile"]]; indices=(2,1,3)) + wideDf = coalesce.(wideDf, 0) + wideFile = joinpath(saveDir, "combined_$(tag)_sp.tsv") + writeTSVWithEmptyFirstHeader(wideDf, wideFile; delim = '\t') + println("Files written under ", saveDir) + end + + return selectedDf +end + +# ----------------------- +# main procedure +# ----------------------- +# This section executes if the script is run directly. It will not execute if called/imported into another script. 
# -----------------------
# main procedure
# -----------------------
# This section executes only when the script is run directly (not when included
# or imported into another script).
if abspath(PROGRAM_FILE) == @__FILE__
    using ArgParse
    using Printf   # FIX: @printf is used below; Printf must be in scope for direct runs

    # Define argument parser
    s = ArgParseSettings()
    @add_arg_table s begin
        "--combineOpt"
            help = "Combination option: max, min, or mean (default: mean)"
            arg_type = String
            default = "mean"
        "--meanEdgesPerGene"
            help = "Mean number of edges per unique gene."
            arg_type = Int
            required = true
        "--useMeanEdgesPerGeneMode"
            help = "controls whether edge selection is done per-group or globally. If true, selects length(unique(targs)) * meanEdgesPerGene edges.
            If `false`, selects the top meanEdgesPerGene edges per target gene."
            arg_type = Bool
            default = true
        "--saveDir"
            help = "Directory in which to save output files."
            arg_type = String
            default = ""
        "--saveName"
            help = "Base name for the saved file."
            arg_type = String
            default = ""
        # FIX: the positional argument was declared as "files..." — ArgParse takes
        # that literally as the argument name, so parsedArgs["files"] raised a
        # KeyError. Declare it as "files" and use nargs='+' for one-or-more paths.
        "files"
            help = "List of GRN files (TSV format) to combine."
            arg_type = String
            nargs = '+'
    end

    parsedArgs = parse_args(s)
    fileList = parsedArgs["files"]
    combineOpt = parsedArgs["combineOpt"]
    meanEdgesPerGene = parsedArgs["meanEdgesPerGene"]
    useMeanEdgesPerGeneMode = parsedArgs["useMeanEdgesPerGeneMode"]
    saveDir = isempty(parsedArgs["saveDir"]) ? nothing : parsedArgs["saveDir"]
    saveName = parsedArgs["saveName"]

    @printf("Combining %d files with combination option: %s\n", length(fileList), combineOpt)

    # NOTE(review): keyword names here (method=, outputDir=) must match the
    # aggregateNetworks signature (not visible in this chunk) — confirm; also
    # useMeanEdgesPerGeneMode is parsed but never forwarded — verify intent.
    selectedDf = aggregateNetworks(fileList; method=combineOpt, meanEdgesPerGene=meanEdgesPerGene, outputDir=saveDir, saveName=saveName)

    # Print a short summary of the final network
    @printf("Final network has %d edges spanning %d unique genes.\n", nrow(selectedDf), length(unique(selectedDf.Gene)))
end
"""
    chooseLambda!(grnData, buildGrn; instabilityLevel = "Gene", targetInstability = 0.05)

For each target gene (`instabilityLevel == "Gene"`) or for the whole network
(`instabilityLevel == "Network"`), pick the lambda whose StARS instability is
closest to `targetInstability`, then store into `buildGrn`:

- `buildGrn.lambda`            — chosen lambda (scalar for "Network" / single gene,
                                 vector for per-gene selection),
- `buildGrn.networkStability`  — genes × TFs stability slice at the chosen lambda(s)
                                 with `Inf` entries (illegal TF→gene edges) zeroed,
- `buildGrn.betas`             — corresponding regression coefficients.

Throws an error for any other `instabilityLevel` value.
"""
function chooseLambda!(grnData::GrnData, buildGrn::BuildGrn; instabilityLevel = "Gene", targetInstability = 0.05)
    totLambdas, totNetGenes, totNetTfs = size(grnData.stabilityMat)
    networkStability = zeros(Float64, totNetGenes, totNetTfs)
    lambdaRange = grnData.lambdaRange
    betas = zeros(totNetGenes, totNetTfs)
    lambdaTrack = Float64[]   # typed (was untyped []) — only Float64 lambdas are pushed

    ## transform StARS instabilities into stabilities
    if instabilityLevel == "Gene"
        totMins = 0   # genes whose target instability was hit at the smallest lambda
        totMaxs = 0   # ... at the largest lambda
        for targ = 1:totNetGenes
            currInstabs = grnData.geneInstabilities[targ, :]
            devs = abs.(currInstabs .- targetInstability)
            globalMin = minimum(devs)
            # among ties, take the last index (largest lambda closest to target)
            minInds = findall(x -> x == globalMin, devs)
            minInd = minInds[end]

            lambdaRangeGene = grnData.lambdaRangeGene[targ]
            push!(lambdaTrack, lambdaRangeGene[minInd])

            networkStability[targ, :] = grnData.stabilityMat[minInd, targ, :]
            betas[targ, :] = grnData.betas[targ, :, minInd]

            if minInd == 1
                totMins += 1
            elseif minInd == totLambdas
                totMaxs += 1
            end
        end
        if totMins > 0
            println("Target instability reached at minimum lambda for ", string(totMins), " gene(s)")
        end
        if totMaxs > 0
            println("Target instability reached at maximum lambda for ", string(totMaxs), " gene(s)")
        end

    elseif instabilityLevel == "Network"
        println("Network instabilities detected.")
        # find the single lambda corresponding to the cutoff
        devs = abs.(grnData.netInstabilities .- targetInstability)
        globalMin = findmin(devs)
        minInds = first(Tuple(globalMin[2]))
        globalMin = globalMin[1]
        # take the largest lambda that is closest to targetInstability
        minInd = minInds[end]
        if minInd == 1
            println("Minimum lambda was used for maximum instability ", string(grnData.netInstabilities[minInd]), " to reach target cut = ", string(targetInstability), ".")
        elseif minInd == totLambdas
            println("Maximum lambda was used for minimum instability ", string(grnData.netInstabilities[minInd]), " to reach target cut = ", string(targetInstability), ".")
        end
        networkStability[:, :] = grnData.stabilityMat[minInd, :, :]
        betas = grnData.betas[:, :, minInd]
        lambdaTrack = lambdaRange[minInds]   # scalar in this branch
    else
        # FIX: the message referenced a nonexistent kwarg name "instabSource"
        error("instabilityLevel not recognized, should be either Gene or Network.")
    end

    # ssMatrix has infinity entries to mark illegal TF-gene interactions
    # (e.g., TF mRNA TFA cannot be used to predict TF gene expression)
    networkStabilityVec = networkStability[:]
    networkStabilityVec[isinf.(networkStabilityVec)] .= 0
    networkStability = reshape(networkStabilityVec, totNetGenes, totNetTfs)

    # FIX: `convert(Float64, lambdaTrack)` threw a MethodError when lambdaTrack
    # was a one-element Vector (single-gene case); unwrap the scalar instead.
    if length(lambdaTrack) == 1
        buildGrn.lambda = Float64(first(lambdaTrack))
    else
        buildGrn.lambda = convert(Vector{Float64}, lambdaTrack)
    end
    buildGrn.networkStability = networkStability
    buildGrn.betas = betas
end
"""
    rankEdges!(expressionData, tfaData, grnData, buildGrn;
               mergedTFsData=nothing, useMeanEdgesPerGeneMode=true,
               meanEdgesPerGene=20, correlationWeight=1, outputDir=nothing)

Rank TF→gene edges by stability, compute partial correlations and signed
quantiles, optionally expand merged (degenerate) TFs back into individual TFs,
and store the resulting edge tables in `buildGrn`. A non-empty `outputDir` is
required; the populated `buildGrn` is serialized there as `grnOutMat.jld`.

Edge selection: when `useMeanEdgesPerGeneMode` is true, keep the global top
`length(unique(targets)) * meanEdgesPerGene` edges; otherwise keep the top
`meanEdgesPerGene` edges per target gene (via `firstNByGroup`).
"""
function rankEdges!(expressionData::GeneExpressionData, tfaData::PriorTFAData, grnData::GrnData, buildGrn::BuildGrn;
                    mergedTFsData::Union{mergedTFsResult, Nothing}=nothing,
                    useMeanEdgesPerGeneMode = true,
                    meanEdgesPerGene = 20,
                    correlationWeight = 1,
                    outputDir::Union{String, Nothing}=nothing)

    # FIX: `isempty(outputDir)` raised a MethodError for the default `nothing`;
    # check for nothing explicitly before testing emptiness.
    if outputDir === nothing || isempty(outputDir)
        error("You must provide a non-empty outputDir to save the results.")
    end

    predictorMat = grnData.predictorMat
    allPredictors = grnData.allPredictors

    totLambdas, totNetGenes, totNetTfs = size(grnData.stabilityMat)
    totInts = totNetGenes * totNetTfs
    targs = repeat(expressionData.targGenes, totNetTfs, 1)
    regs = vec(repeat(permutedims(allPredictors), totNetGenes, 1))
    rankTmp = buildGrn.networkStability[:]
    # keep nonzero, remove infinite values (e.g., TF-TF edges when TF mRNA used for TFA)
    keepInds = findall(x -> x != 0 && x != Inf, rankTmp)
    rankTmp = rankTmp[keepInds]
    inds = reverse(sortperm(rankTmp))
    rankings = sort(rankTmp, rev = true)
    regs = regs[keepInds[inds]]
    targs = targs[keepInds[inds]]
    totInfInts = length(rankings)

    if useMeanEdgesPerGeneMode
        totQuantEdges = length(unique(targs)) * meanEdgesPerGene
        selectionIndices = 1:min(totQuantEdges, totInfInts)
    else
        selectionIndices = firstNByGroup(targs, meanEdgesPerGene)
    end

    # Compute quantiles over the selected edges (ecdf is from StatsBase)
    F = ecdf(rankings[selectionIndices])
    quantiles = F.(rankings[selectionIndices])

    ## take what's in the meanEdgesPerGene network and get partial correlations
    allCoefs = zeros(totNetGenes, totNetTfs)
    allQuants = zeros(totNetGenes, totNetTfs)
    allStabsTest = buildGrn.networkStability
    keptTargs = permutedims(targs[selectionIndices])
    uniTargs = unique(keptTargs)
    totUniTargs = length(uniTargs)
    tfsPerGene = zeros(totUniTargs, 1)

    for targ = ProgressBar(1:totUniTargs)
        currTarg = uniTargs[targ]
        targRankInds = last.(Tuple.(findall(x -> x == currTarg, keptTargs)))
        currRegs = regs[targRankInds]

        # Deduplicate currRegs while keeping aligned targRankInds
        uniCurrRegsInds = firstNByGroup(currRegs, 1)  # index where each unique regulator first appears
        uniCurrRegs = currRegs[uniCurrRegsInds]
        targRankInds = targRankInds[uniCurrRegsInds]
        # Target index in the full gene list
        targInd = last.(Tuple.(findall(x -> x == currTarg, expressionData.targGenes)))
        tfsPerGene[targ] = Int.(length(targRankInds))

        # Match predictors
        matchedRegs = intersect(allPredictors, currRegs)
        regressIndsMat = findall(in(matchedRegs), allPredictors)
        rankVecInds = findall(in(matchedRegs), uniCurrRegs)
        # Prepare input matrix: target + predictors
        currTargVals = vec(transpose(grnData.responseMat[targInd, :]))
        currPredVals = transpose(predictorMat[regressIndsMat, :])
        combTargPreds = vcat(currTargVals', currPredVals')'
        combTargPreds = permutedims(combTargPreds')
        prho = partialCorrelationMat(combTargPreds; first_vs_all = true)
        prho = vec(prho[2:end])

        # FIX: the original test `x == NaN` is always false (NaN != NaN), so the
        # singular-pcorr branch could never trigger; use isnan instead.
        if !any(isnan, prho)
            allCoefs[targInd, regressIndsMat] = prho
        else
            println(currTarg, " pcorr was singular, # TFs = ", string(length(regressIndsMat)))
        end

        allQuants[targInd, regressIndsMat] = quantiles[targRankInds[rankVecInds]]
        allStabsTest[targInd, regressIndsMat] = allStabsTest[targInd, regressIndsMat] + (correlationWeight .* transpose(round.(abs.(prho), digits = 4)))
    end

    inPriorMat = sign.(abs.(grnData.priorMatProcessed))

    mergeTfLocVec = zeros(totNetTfs)  # tracks merged TFs (needed for partial correlation calculation)
    if (mergedTFsData !== nothing) &&
       (mergedTFsData.mergedPrior !== nothing) &&
       (mergedTFsData.mergedTFMap !== nothing)

        mergedTFs = mergedTFsData.mergedTFMap[:, 1]
        individualTFs = mergedTFsData.mergedTFMap[:, 2]
        totMerged = length(individualTFs)

        rmInds = []       # merged-TF columns to remove from regulators
        addRegs = []
        addInts = []
        addCoefs = []
        addPMat = []
        addPredMat = []
        addLoc = []
        addQuants = []
        for mind = 1:totMerged
            mTf = mergedTFs[mind]
            inputLocs = findall(x -> x == mTf, allPredictors)
            rmInds = [rmInds; inputLocs]
            if length(inputLocs) > 0
                # intersect ensures the TF was a potential regulator (e.g., based on gene expression)
                indTfs = intersect(permutedims(split(individualTFs[mind], ", ")), tfaData.pRegsNoTfa)
                totIndTfs = length(indTfs)
                for indt = 1:totIndTfs
                    indTf = indTfs[indt]
                    append!(addRegs, fill(indTf, 1))
                    append!(addInts, allStabsTest[:, inputLocs])
                    append!(addPMat, inPriorMat[:, inputLocs])
                    append!(addQuants, allQuants[:, inputLocs])
                    # NOTE(review): this appends allQuants into addCoefs — looks like a
                    # copy-paste slip (allCoefs expected?); preserved as-is, verify intent.
                    append!(addCoefs, allQuants[:, inputLocs])
                    append!(addPredMat, predictorMat[inputLocs, :])
                    append!(addLoc, mind)
                end
            end
        end
        println("Total of ", string(length(rmInds)), " TFs expanded.")
        keepInds = setdiff(1:totNetTfs, rmInds)
        # remove merged TFs and add individual TFs
        if length(addRegs) > 0
            allPredictors = collect(allPredictors[keepInds])
            append!(allPredictors, addRegs)
            allStabsTest = hcat(allStabsTest[:, keepInds], reshape(addInts, size(allStabsTest)[1], Int(size(addInts)[1] / size(allStabsTest)[1])))
            allCoefs = hcat(allCoefs[:, keepInds], reshape(addCoefs, size(allStabsTest)[1], Int(size(addInts)[1] / size(allStabsTest)[1])))
            allQuants = hcat(allQuants[:, keepInds], reshape(addQuants, size(allStabsTest)[1], Int(size(addInts)[1] / size(allStabsTest)[1])))
            inPriorMat = hcat(inPriorMat[:, keepInds], reshape(addPMat, size(inPriorMat)[1], Int(size(addPMat)[1] / size(inPriorMat)[1])))
            predictorMat = vcat(predictorMat[keepInds, :], reshape(addPredMat, Int((size(addPredMat)[1]) / size(predictorMat)[2]), size(predictorMat)[2]))
            append!(mergeTfLocVec, addLoc)
        end
    else
        println("No merged TFs file found.")
    end

    ## re-rank based on possibly de-merged TFs
    rankings = allStabsTest[:]
    coefVec = allCoefs[:]
    inPriorVec = inPriorMat[:]
    totNetTfs = length(allPredictors)
    totInts = totNetGenes * totNetTfs
    targs = repeat((expressionData.targGenes), totNetTfs, 1)
    regs1 = repeat(permutedims(allPredictors), totNetGenes, 1)
    regs = reshape(regs1, totInts, 1)

    rankings = convert(Vector{Float64}, rankings)
    coefVec = convert(Vector{Float64}, coefVec)
    inPriorVec = convert(Vector{Float64}, inPriorVec)
    predictorMat = convert(Matrix{Float64}, predictorMat)
    allStabsTest = convert(Matrix{Float64}, allStabsTest)
    allCoefs = convert(Matrix{Float64}, allCoefs)
    allQuants = convert(Matrix{Float64}, allQuants)
    inPriorMat = convert(Matrix{Float64}, inPriorMat)

    ## only keep nonzero rankings
    keepRankings = findall(x -> x != 0 && x != Inf, rankings)
    indsMerged = reverse(sortperm(rankings[keepRankings]))

    # update info sources
    rankings = rankings[keepRankings[indsMerged]]
    coefVec = coefVec[keepRankings[indsMerged]]
    inPriorVec = inPriorVec[keepRankings[indsMerged]]
    regs = regs[keepRankings[indsMerged]]
    targs = targs[keepRankings[indsMerged]]
    totInfInts = length(rankings)

    # Compute quantiles over the re-ranked edges
    F = ecdf(rankings)
    quantiles = F.(rankings)

    # signedQuantile: quantile carrying the sign of the partial correlation
    signedQuantile = sign.(coefVec) .* quantiles

    minRank, maxRank = extrema(rankings)
    rankRange = max(maxRank - minRank, eps())  # prevent division by zero
    # "Yes"/"No" flag for whether the edge appeared in the prior
    strokeDashArray = ifelse.(inPriorVec .!= 0, "Yes", "No")
    networkMatrix = hcat(regs, targs, signedQuantile, rankings, coefVec, strokeDashArray)
    if useMeanEdgesPerGeneMode
        totQuantEdges = length(unique(targs)) * meanEdgesPerGene
        selectionIndices = 1:min(totQuantEdges, totInfInts)
    else
        selectionIndices = firstNByGroup(targs, meanEdgesPerGene)
    end

    networkMatrixSubset = hcat(regs[selectionIndices], targs[selectionIndices], signedQuantile[selectionIndices],
                               rankings[selectionIndices], coefVec[selectionIndices], strokeDashArray[selectionIndices])

    buildGrn.regs = regs
    buildGrn.targs = targs
    buildGrn.signedQuantile = signedQuantile
    buildGrn.rankings = rankings
    buildGrn.partialCorrelation = coefVec
    buildGrn.inPrior = strokeDashArray
    buildGrn.networkMat = networkMatrix
    buildGrn.networkMatSubset = networkMatrixSubset
    buildGrn.inPriorVec = inPriorVec
    buildGrn.mergeTfLocVec = mergeTfLocVec

    # outputDir was validated non-empty above
    outputFile = joinpath(outputDir, "grnOutMat.jld")
    save_object(outputFile, buildGrn)
end
"""
GRN

Functions:
- preparePredictorMat!
- preparePenaltyMatrix!
- constructSubsamples
- bstarsWarmStart
- bstartsEstimateInstability

Dependencies:
"""

# (A commented-out legacy GrnData struct definition lived here; the live
#  definition is kept elsewhere in the package, e.g. Types.jl.)

# preparePredictorMat!(grnData, expressionData, priorData; tfaOpt = "")
#
# Populate `grnData` with the predictor matrix, predictor names, response
# matrix and a processed prior matrix:
# - `tfaOpt != ""` ("noTfa"): use the prior/regulator sets without TFA and
#   TF mRNA levels as predictors.
# - `tfaOpt == ""`: stack TFA estimates (priorData.medTfas) on top of mRNA
#   rows for potential regulators that have no prior information.
function preparePredictorMat!(grnData::GrnData, expressionData::GeneExpressionData, priorData::PriorTFAData; tfaOpt::String = "")
    if tfaOpt != ""
        println("noTfa option")
        pRegs = priorData.pRegsNoTfa;
        pTargs = priorData.pTargsNoTfa;
        priorMatrix = priorData.priorMatrixNoTfa;
    else
        pRegs = priorData.pRegs;
        pTargs = priorData.pTargs;
        priorMatrix = priorData.priorMatrix;
    end
    # TFs in potRegs that have expression data but are not in the prior; get
    # their indices in the TF expression matrix.
    uniNoPriorRegs = setdiff(expressionData.potRegsmRNA, pRegs)
    uniNoPriorRegInds = findall(in(uniNoPriorRegs), expressionData.potRegsmRNA)

    # allPredictors includes both the prior regulators and the potential
    # regulators that were not in the prior (prior regs come first).
    allPredictors = vcat(pRegs, uniNoPriorRegs)
    totPreds = length(allPredictors)

    # Build a prior matrix whose rows follow the order of targGenes.
    # NOTE(review): targGeneInds are in targGenes order while priorGeneInds are
    # in pTargs order — the row assignment below is only aligned if the two
    # intersections enumerate the shared genes in the same order. TODO confirm
    # that targGenes and pTargs are consistently ordered upstream.
    targGeneInds = findall(in(pTargs), expressionData.targGenes)
    priorGeneInds = findall(in(expressionData.targGenes), pTargs)
    totTargGenes = length(expressionData.targGenes)
    totPRegs = length(pRegs)
    priorMat = zeros(totTargGenes, totPreds)
    priorMat[targGeneInds, 1:totPRegs] = priorMatrix[priorGeneInds, :]

    # predictorMat is TFA (when available) plus TF mRNA when TFA is not available
    if tfaOpt == ""
        # TFA: stack TFA estimates on top of mRNA for TFs not in prior
        predictorMat = [priorData.medTfas; expressionData.potRegMatmRNA[uniNoPriorRegInds, :]]
    else
        # If not using TFA: use mRNA for all predictors
        currPredMat = zeros(totPreds, length(expressionData.cellLabels))
        for prend = 1:totPreds
            prendInd = findall(x -> x == allPredictors[prend], expressionData.potRegsmRNA)
            currPredMat[prend, :] = expressionData.potRegMatmRNA[prendInd, :]
        end
        predictorMat = currPredMat
        println("TF mRNA used.")
    end
    grnData.predictorMat = predictorMat
    grnData.allPredictors = allPredictors
    grnData.responseMat = expressionData.targGeneMat
    grnData.priorMatProcessed = priorMat
end
"""
    preparePenaltyMatrix!(expressionData, grnData;
                          priorFilePenalties = String[], lambdaBias = [0.5], tfaOpt = "")

Build the genes × predictors lasso penalty matrix and store it in
`grnData.penaltyMat`. Entries default to 1; interactions present in any prior
file get the *minimum* lambda bias across files; entries are set to `Inf` to
forbid edges:
- `tfaOpt != ""` (TF mRNA mode): a TF may not predict its own mRNA,
- `tfaOpt == ""` (TFA mode): TFs with no prior targets (no TFA estimate)
  may not predict their own expression.
"""
function preparePenaltyMatrix!(expressionData::GeneExpressionData, grnData::GrnData;
                               priorFilePenalties::Vector{String} = String[],
                               lambdaBias::Vector{Float64} = [0.5],
                               tfaOpt::String = "")
    # 1. Update penalty matrix: minimum lambda per prior-supported interaction
    minLambdaDict = Dict{Tuple{String,String}, Float64}()
    penaltyMatrix = ones(length(expressionData.targGenes), length(grnData.allPredictors))

    # Index maps for fast lookup (hoisted out of the file loop — they do not
    # depend on the file being read).
    geneToIdx = Dict(gene => i for (i, gene) in enumerate(expressionData.targGenes))
    tfToIdx = Dict(tf => i for (i, tf) in enumerate(grnData.allPredictors))

    # Iterate through each prior file and its associated lambda
    for (filePath, lambda) in zip(priorFilePenalties, lambdaBias)
        priorData = readdlm(filePath)

        # Gene names are in the first column (skipping the header row)
        priorGenes = priorData[2:end, 1]
        # TF names from the header row, dropping empty/missing cells
        # (e.g. the empty top-left corner cell)
        priorTFs = filter(tf -> !ismissing(tf) && !isempty(string(tf)), priorData[1, :])

        for (i, gene) in enumerate(priorGenes)
            for (j, tf) in enumerate(priorTFs)
                if priorData[i+1, j+1] != 0 && haskey(geneToIdx, gene) && haskey(tfToIdx, tf)
                    interaction = (gene, tf)
                    if !haskey(minLambdaDict, interaction) || lambda < minLambdaDict[interaction]
                        minLambdaDict[interaction] = lambda
                    end
                end
            end
        end
    end

    # Apply the penalties using the minimum lambda values. Keys were only
    # inserted after the haskey checks above, so the lookups cannot fail.
    for ((gene, tf), minLambda) in minLambdaDict
        penaltyMatrix[geneToIdx[gene], tfToIdx[tf]] = minLambda
    end

    # 2. Forbidden edges
    totPreds = length(grnData.allPredictors)
    # FIX: was `tfaOpt !== ""` — `===`/`!==` compares object identity and is not
    # reliable for runtime-built strings; use content equality like the sibling
    # functions do.
    if tfaOpt != ""
        # Set the penalty to infinity for positive-feedback edges where a TF's
        # mRNA would serve both as gene expression and as its TFA estimate.
        for pr = 1:totPreds
            targInd = findall(x -> x == grnData.allPredictors[pr], expressionData.targGenes)
            if length(targInd) > 0
                penaltyMatrix[targInd, pr] .= Inf
            end
        end
    else
        # TFA mode: forbid self-prediction for TFs with no prior targets
        # (no target edges from which to estimate the TF's TFA).
        for pr = 1:totPreds
            if sum(abs.(grnData.priorMatProcessed[:, pr])) == 0
                targInd = findall(x -> x == grnData.allPredictors[pr], expressionData.targGenes)
                if length(targInd) > 0
                    penaltyMatrix[targInd, pr] .= Inf
                end
            end
        end
    end

    grnData.penaltyMat = penaltyMatrix
end
= [] + end + + subsampleSize = floor(subsampleFrac*length(trainInds)) + subsampleSize = convert(Int64, subsampleSize) + # get subsamp indices + subsamps = zeros(totSS,subsampleSize) + for ss = 1:totSS + randSubs = randperm(totSamps) + randSubs = randSubs[1:subsampleSize] + subsamps[ss,:] = randSubs + end + subsamps = convert(Matrix{Int}, subsamps) + grnData.subsamps = subsamps +end + +function bstarsWarmStart(expressionData::GeneExpressionData, tfaData::PriorTFAData, grnData::GrnData; minLambda = 0.01, maxLambda = 0.5, totLambdasBstars = 20, totSS = 5, targetInstability = 0.05, zTarget = false) + # Determine the lambda levels to test + lambdaRange = collect(range(minLambda, stop = maxLambda, length = totLambdasBstars)) + #lamLog10step = 1/totLambdasBstars + #logLamRange = log10(minLambda):lamLog10step:log10(maxLambda) + #lambdaRange = 10 .^ (logLamRange) + lambdaRange = reverse(lambdaRange) + totResponses = size(grnData.responseMat)[1] + + instabilitiesLb = zeros(totResponses,totLambdasBstars) + instabilitiesUb = zeros(totResponses,totLambdasBstars) + minLambdas = zeros(totResponses,1) + maxLambdas = zeros(totResponses,1) + + responsePredInds = Vector{Vector{Int}}(undef,0) + for res = 1:totResponses + currWeights = grnData.penaltyMat[res,:] + push!(responsePredInds,findall(x -> x!=Inf, currWeights)) + end + + netInstabilitiesLb = zeros(totResponses, totLambdasBstars) + netInstabilitiesUb = zeros(totResponses, totLambdasBstars) + totEdges = zeros(totResponses) # denominator for network Instabilities + + theta2save = Vector{Matrix{Float64}}(undef, totResponses) # Debugging # + + Threads.@threads for res in ProgressBar(1:totResponses) # can be a parfor loop + predInds = responsePredInds[res] + currPredNum = length(predInds) + penaltyFactor = grnData.penaltyMat[res, predInds] + totEdges[res] = currPredNum + ssVals = zeros(totLambdasBstars,currPredNum) + for ss = 1:totSS + subsamp = grnData.subsamps[ss,:] + dt = fit(ZScoreTransform, grnData.predictorMat[predInds, 
# bstarsWarmStart(expressionData, tfaData, grnData;
#                 minLambda = 0.01, maxLambda = 0.5, totLambdasBstars = 20,
#                 totSS = 5, targetInstability = 0.05, zTarget = false)
#
# bStARS warm start: over a coarse (decreasing) lambda grid and a small number
# of subsamples, fit penalized lasso models (glmnet) per response and compute
# per-gene and network-level instability lower/upper bounds. From those bounds,
# derive per-gene and network-wide [minLambda, maxLambda] brackets around the
# target instability, storing everything in grnData
# (minLambdaNet/maxLambdaNet/minLambdas/maxLambdas/netInstabilitiesLb/Ub/
#  instabilitiesLb/Ub).
function bstarsWarmStart(expressionData::GeneExpressionData, tfaData::PriorTFAData, grnData::GrnData; minLambda = 0.01, maxLambda = 0.5, totLambdasBstars = 20, totSS = 5, targetInstability = 0.05, zTarget = false)
    # Lambda grid, reversed so glmnet sees a decreasing path
    lambdaRange = collect(range(minLambda, stop = maxLambda, length = totLambdasBstars))
    lambdaRange = reverse(lambdaRange)
    totResponses = size(grnData.responseMat)[1]

    instabilitiesLb = zeros(totResponses, totLambdasBstars)
    instabilitiesUb = zeros(totResponses, totLambdasBstars)
    minLambdas = zeros(totResponses, 1)
    maxLambdas = zeros(totResponses, 1)

    # Predictors with finite penalty per response (Inf penalty = forbidden edge)
    responsePredInds = Vector{Vector{Int}}(undef, 0)
    for res = 1:totResponses
        currWeights = grnData.penaltyMat[res, :]
        push!(responsePredInds, findall(x -> x != Inf, currWeights))
    end

    netInstabilitiesLb = zeros(totResponses, totLambdasBstars)
    netInstabilitiesUb = zeros(totResponses, totLambdasBstars)
    totEdges = zeros(totResponses)   # denominator for network instabilities

    theta2save = Vector{Matrix{Float64}}(undef, totResponses)   # debugging aid

    Threads.@threads for res in ProgressBar(1:totResponses)
        predInds = responsePredInds[res]
        currPredNum = length(predInds)
        penaltyFactor = grnData.penaltyMat[res, predInds]
        totEdges[res] = currPredNum
        ssVals = zeros(totLambdasBstars, currPredNum)
        for ss = 1:totSS
            subsamp = grnData.subsamps[ss, :]
            # z-score predictors within the subsample (per row, dims=2)
            dt = fit(ZScoreTransform, grnData.predictorMat[predInds, subsamp], dims=2)
            currPreds = transpose(StatsBase.transform(dt, grnData.predictorMat[predInds, subsamp]))
            if zTarget
                dt = fit(ZScoreTransform, grnData.responseMat[res, subsamp], dims=1)
                currResponses = StatsBase.transform(dt, grnData.responseMat[res, subsamp])
            else
                currResponses = grnData.responseMat[res, subsamp]
            end
            lsoln = glmnet(currPreds, currResponses, penalty_factor = penaltyFactor, lambda = lambdaRange, alpha = 1.0)
            currBetas = lsoln.betas
            # count nonzero coefficients per lambda across subsamples
            ssVals .+= abs.(sign.(currBetas))'
        end
        theta2 = (1/totSS) * ssVals   # empirical edge selection probability
        theta2save[res] = theta2
        # bStARS lower bound on instability (per edge, averaged)
        instabilitiesLb[res, :] = 2 * (1/currPredNum) .* sum(theta2 .* (1 .- theta2), dims=2)
        netInstabilitiesLb[res, :] = currPredNum * (instabilitiesLb[res, :])
        theta2mean = sum(theta2, dims=2) ./ currPredNum
        # bStARS upper bound on instability
        instabilitiesUb[res, :] = 2 * theta2mean .* (1 .- theta2mean)
        netInstabilitiesUb[res, :] = currPredNum * instabilitiesUb[res, :]
    end

    totEdges = sum(totEdges)
    netInstabilitiesLb = sum(netInstabilitiesLb, dims=1)[:]
    netInstabilitiesUb = sum(netInstabilitiesUb, dims=1)[:]

    for res = 1:totResponses
        # take the supremum: find the lambda of maximal instability and flatten
        # all smaller lambdas (later grid positions) to that value
        maxLb = findmax(instabilitiesLb[res, :])
        maxLbInd = findall(x -> x == maxLb[1], instabilitiesLb[res, :])
        maxLb = maxLb[1]
        instabilitiesLb[res, maxLbInd[end]:end] .= maxLb
        maxUb = findmax(instabilitiesUb[res, :])
        maxUbInd = findall(x -> x == maxUb[1], instabilitiesUb[res, :])
        maxUb = maxUb[1]
        instabilitiesUb[res, maxUbInd[end]:end] .= maxUb
        # per-gene bracket: min lambda from the lower bound...
        # NOTE(review): the surrounding comments talk about using the upper
        # bound here, but the code uses instabilitiesLb — verify intent.
        xx = findmin(abs.(instabilitiesLb[res, :] .- targetInstability))
        xx = xx[2]
        minLambdas[res] = (lambdaRange)[xx[end]]   # rightmost (smallest) lambda at the minimum deviation
        # ...and max lambda from the upper bound (guaranteed >= target-instability lambda)
        xx = findmin(abs.(instabilitiesUb[res, :] .- targetInstability))
        xx = findall(x -> x == xx[1], abs.(instabilitiesUb[res, :] .- targetInstability))
        maxLambdas[res] = (lambdaRange)[xx[end]]
    end

    # Network-level bounds: normalize by total edges, take the supremum, then
    # bracket the target instability the same way as per gene.
    netInstabilitiesUb = netInstabilitiesUb ./ totEdges
    netInstabilitiesLb = netInstabilitiesLb ./ totEdges
    maxLb = findmax(netInstabilitiesLb)
    maxLbInd = findall(x -> x == maxLb[1], netInstabilitiesLb)
    netInstabilitiesLb[maxLbInd[end]:end] .= maxLb[1]   # supremum for lambdas below the instability max
    maxUb = findmax(netInstabilitiesUb)
    maxUbInd = (findall(x -> x == maxUb[1], netInstabilitiesUb))
    netInstabilitiesUb[maxUbInd[end]:end] .= maxUb[1]
    xx = findmin(abs.(netInstabilitiesLb .- targetInstability))
    maxInstInd = findall(x -> x == xx[1], abs.(netInstabilitiesLb .- targetInstability))
    minLambdaNet = (lambdaRange)[maxInstInd[end]]
    xx = findmin(abs.(netInstabilitiesUb .- targetInstability))
    minInstInd = findall(x -> x == xx[1], abs.(netInstabilitiesUb .- targetInstability))
    maxLambdaNet = (lambdaRange)[minInstInd[end]]

    grnData.minLambdaNet = minLambdaNet
    grnData.maxLambdaNet = maxLambdaNet
    grnData.maxLambdas = maxLambdas
    grnData.minLambdas = minLambdas
    grnData.netInstabilitiesLb = netInstabilitiesLb
    grnData.netInstabilitiesUb = netInstabilitiesUb
    grnData.instabilitiesLb = instabilitiesLb
    grnData.instabilitiesUb = instabilitiesUb
end
size(grnData.responseMat) # totResponses is same as length(grnData["targGenes"]) + totPreds = size(grnData.predictorMat,1) + + # if instabilityLevel == "Gene" + # minLambda = minimum(grnData.minLambdas) + # maxLambda = maximum(grnData.maxLambdas) + # elseif instabilityLevel == "Network" + # minLambda = grnData.minLambdaNet + # maxLambda = grnData.maxLambdaNet + # else + # println("Use either 'Gene' or 'Network' for instabilityLevel") + # end + + # Gene Level Instabilities + lambdaRangeGene = Vector{Vector{Float64}}(undef, totResponses) + for res in 1:totResponses + λmin = grnData.minLambdas[res] + λmax = grnData.maxLambdas[res] + lambdaRangeGene[res] = reverse(collect(range(λmin, stop=λmax, length=totLambdas))) + end + grnData.lambdaRangeGene = lambdaRangeGene + + # Network Level Instability + minLambda = grnData.minLambdaNet + maxLambda = grnData.maxLambdaNet + + lambdaRange = collect(range(minLambda, stop = maxLambda, length = totLambdas)) + lambdaRange = reverse(lambdaRange) + grnData.lambdaRange = lambdaRange + + geneInstabilities = zeros(totResponses,totLambdas) + netInstabilities = zeros(totLambdas,1) + totEdges = 0 # denominator for network Instabilities + # store number of subsamples for which an edge was nonzero, given that some + # prior weights can be set to infinity, track to make sure these edges are + # not counted + ssMatrix = Inf*ones(totLambdas,totResponses,totPreds) + subsamps = grnData.subsamps + totSS = size(subsamps)[1] + + # get (finite) predictor indices for each response + responsePredInds = Vector{Vector{Int}}(undef,0) + for res = 1:totResponses + currWeights = grnData.penaltyMat[res,:] + push!(responsePredInds,findall(x -> x!=Inf, currWeights)) + end + + totEdges = zeros(totResponses) + betas = Array{Float64,3}(undef, totResponses, totPreds, totLambdas) + Threads.@threads for res in ProgressBar(1:totResponses) + lambdaRange = instabilityLevel == "Network" ? 
+ grnData.lambdaRange : # single shared vector + grnData.lambdaRangeGene[res] # per‑gene vector + predInds = responsePredInds[res] + currPredNum = length(predInds) + totEdges[res] = currPredNum + penaltyFactor = grnData.penaltyMat[res,predInds] + ssVals = zeros(totLambdas,currPredNum) + for ss = 1:totSS + subsamp = subsamps[ss,:] + dt = fit(ZScoreTransform, grnData.predictorMat[predInds, subsamp], dims=2) + currPreds = transpose(StatsBase.transform(dt, grnData.predictorMat[predInds, subsamp])) + if zTarget + dt = fit(ZScoreTransform, grnData.responseMat[res, subsamp], dims=1) + currResponses = StatsBase.transform(dt, grnData.responseMat[res, subsamp]) + else + currResponses = grnData.responseMat[res, subsamp] + end + lsoln = glmnet(currPreds, currResponses, penalty_factor = penaltyFactor, lambda = lambdaRange, alpha = 1.0) + currBetas = lsoln.betas + betas[res,predInds, :] = currBetas + ssVals = ssVals + abs.(sign.(currBetas))' + end + ssMatrix[:,res,predInds] = ssVals + end + + for res = 1:totResponses + currWeights = grnData.penaltyMat[res,:] + predInds = responsePredInds[res] + currPredNum = length(predInds) + ssVals = zeros(totLambdas,currPredNum) + ssVals[:,:] = ssMatrix[:,res,predInds] + theta2 = (1/totSS)*ssVals # empirical edge probability, lambdas X currPreds + instabilitiesPerEdge = 2*(theta2 .* (1 .- theta2)) + aveInstabilities = mean(instabilitiesPerEdge,dims=2) # lambdas X 1 + maxUb = (findmax(aveInstabilities)) + instabSUP = aveInstabilities + maxUbInd = first(Tuple(maxUb[2])) + instabSUP[maxUbInd:end,1] .= maxUb[1] + geneInstabilities[res,:] = instabSUP + end + + ## calculate instabilities network-wise + currSS = zeros(totResponses,totPreds) + instabRange = zeros(totLambdas,1) + for lind = 1:totLambdas # start at highest lambda (lowest instability and work down) + currSS[:,:] = ssMatrix[lind,:,:] + theta2 = (1/totSS)*currSS # empirical edge probability, responses X currPreds + instabilitiesPerEdge = 2*(theta2 .* (1 .- theta2)) + instabVec = 
instabilitiesPerEdge[:]
        validEdges = findall(isfinite.(currSS[:])) # limit to finite edges
        # FIX(review): this previously read `findmax(instabRange)[1]`, but
        # `instabRange` is initialized to zeros and never written anywhere, so
        # the supremum clamp was always max(mean, 0) — a no-op for nonnegative
        # instabilities (and the inner single-argument `max(instabMax)` was
        # redundant). Use the running supremum of the instabilities already
        # computed — the loop walks from the largest lambda (lowest
        # instability) downward — so the network instability curve is monotone
        # non-decreasing, mirroring the per-gene supremum applied to
        # `instabSUP` above.
        instabMax = lind == 1 ? 0.0 : maximum(view(netInstabilities, 1:lind-1))
        netInstabilities[lind] = max(mean(instabVec[validEdges]), instabMax)
    end
    grnData.netInstabilities = vec(netInstabilities)
    grnData.geneInstabilities = geneInstabilities
    grnData.stabilityMat = ssMatrix
    grnData.betas = betas

    # Persist the full GrnData object when an output directory is given.
    if outputDir !== nothing && outputDir !== ""
        outputFile = joinpath(outputDir, "instabOutMat.jld")
        save_object(outputFile, grnData)
    end
end
diff --git a/src/grn/RefineTFA.jl b/src/grn/RefineTFA.jl
new file mode 100755
index 0000000..611a935
--- /dev/null
+++ b/src/grn/RefineTFA.jl
@@ -0,0 +1,94 @@
"""
GRN

Main function:
- `refineTFA`: Combines GRNs using merged TFs, TFA data, and other metadata.

Dependencies:
- `Data.GeneExpressionData` and `Data.PriorTFAData`
- `Prior.mergeDegenerateTFs`
"""
# Access dependencies from the package; no need to include files again
# using ..Data # For GeneExpressionData
# using ..PriorTFA # For PriorTFAData
# using ..MergeDegenerate

# # Other standard packages
# using PyPlot
# using Statistics
# using CSV
# using DelimitedFiles
# using JLD2
# using NamedArrays

# Export only the function users need
# export combineGRNS2

function refineTFA(data::GeneExpressionData, mergedTFsData::mergedTFsResult;
    priorFile::String = "",
    tfaGeneFile::String = "",
    edgeSS::Int = 0,
    minTargets::Int = 3,
    zTarget::Bool = true,
    geneExprFile::String = "",
    targFile::String = "",
    regFile::String = "",
    outputDir::Union{String, Nothing} = nothing)
    if !isnothing(outputDir)
        mkpath(outputDir)
    end

    # Ensure required expression data are available and non-empty
    requiredFields = [
        :tfaGenes,
        :tfaGeneMat,
        :potRegs,
        :potRegsmRNA,
        :potRegMatmRNA
    ]

    missingFields = [field for field in requiredFields if isempty(getfield(data, field))]

    if !isempty(missingFields)
        println("Missing or empty fields in data: ",
missingFields) + println("Generating required data by loading from input files...") + + requiredFiles = [ + ("Gene Expression File", geneExprFile), + ("Target Gene File", targFile), + ("Potential Regulators File", regFile), + ] + + missingFiles = [name for (name, path) in requiredFiles if isempty(path) || !isfile(path)] + if !isempty(missingFiles) + error("Cannot generate data. Missing input files: ", missingFiles) + end + + data = GeneExpressionData() + loadExpressionData!(data, geneExprFile) + loadAndFilterTargetGenes!(data, targFile; epsilon=0.01) + loadPotentialRegulators!(data, regFile) + processTFAGenes!(data, tfaGeneFile; outputDir = outputDir) + end + + + # 2. Integrate prior information for TFA estimation + tfaData = PriorTFAData() + processPriorFile!(tfaData, data, priorFile; mergedTFsData, minTargets = minTargets); + calculateTFA!(tfaData, data; edgeSS = edgeSS, zTarget = zTarget, outputDir = outputDir); + + # Save TFA as a text file# Save median TFA if outputDir is specified + if !isnothing(outputDir) + namedMedTFA = NamedArray(tfaData.medTfas) + setnames!(namedMedTFA, tfaData.pRegs, 1) + setnames!(namedMedTFA, data.cellLabels, 2) + + outputfile = joinpath(outputDir, "TFA.txt") + open(outputfile, "w") do io + writedlm(io, permutedims(data.cellLabels)) + writedlm(io, namedMedTFA) + end + end + + return tfaData +end diff --git a/src/grn/UtilsGRN.jl b/src/grn/UtilsGRN.jl new file mode 100755 index 0000000..548ef29 --- /dev/null +++ b/src/grn/UtilsGRN.jl @@ -0,0 +1,54 @@ +function firstNByGroup(vect::AbstractVector, N::Integer) + """ + firstNByGroupIndices(vec::AbstractVector, N::Integer) + + Return the indices of the first `N` occurrences of each unique element in `vec`. + + This is useful when you want to subset another array based on limited occurrences per group. + + # Arguments + - `vec`: A vector of values to group by (e.g., transcription factors). + - `N`: The maximum number of elements to select per unique group. 
+ + # Returns + - A vector of indices corresponding to the first `N` entries per group. + """ + seen = Dict{eltype(vect), Int}() + idxs = Int[] + for (i, v) in enumerate(vect) + seen[v] = get(seen, v, 0) + 1 + if seen[v] <= N + push!(idxs, i) + end + end + return idxs +end + + + +# function firstNByGroup(vect::AbstractVector, N::Integer) +# """ +# firstNByGroup(vect::AbstractVector, N::Integer) + +# Return a boolean mask selecting the first `N` occurrences of each unique element in `vec`. + +# This is useful for logical indexing when subsetting arrays with the same length as `vec`. + +# # Arguments +# - `vect`: A vector of values to group by. +# - `N`: The maximum number of elements to select per group. + +# # Returns +# - A boolean mask vector with `true` at positions of the first `N` occurrences per group. +# """ +# seen = Dict{eltype(vect), Int}() +# mask = falses(length(vect)) +# for i in eachindex(vect) +# seen[vect[i]] = get(seen, vect[i], 0) + 1 +# mask[i] = seen[vect[i]] <= N +# end +# return mask +# end + + + diff --git a/src/metrics/CalcPR.jl b/src/metrics/CalcPR.jl new file mode 100755 index 0000000..5a904db --- /dev/null +++ b/src/metrics/CalcPR.jl @@ -0,0 +1,658 @@ +""" + computeMacroMetrics(gsPrecisionsByTf, gsRecallsByTf, gsFprsByTf, gsAuprsByTf, gsArocsByTf; + target_points=1000, min_step=1e-4, step_method=:min_gap) + +Compute macro-averaged PR and ROC curves across all TFs using interpolation. + +Per-TF curves are interpolated onto a common grid and averaged. Only TFs with +non-empty, non-zero recall arrays contribute to the average. + +# Arguments +- `gsPrecisionsByTf` : Vector of per-TF precision arrays. +- `gsRecallsByTf` : Vector of per-TF recall arrays. +- `gsFprsByTf` : Vector of per-TF false positive rate arrays. +- `gsAuprsByTf` : Vector of per-TF AUPR values. +- `gsArocsByTf` : Vector of per-TF AUROC values. + +# Keyword Arguments +- `target_points::Int=1000` : Number of interpolation points if `step_method=:target_points`. 
+- `min_step::Float64=1e-4` : Minimum allowable step size for interpolation. +- `step_method::Symbol=:min_gap` : Step selection method — `:min_gap` or `:target_points`. + +# Returns +An `OrderedDict` with two nested `OrderedDict`s: +- `:macroPR` → `:auprInterpolated`, `:precisions`, `:recalls` +- `:macroROC` → `:aurocInterpolated`, `:fprs`, `:tprs` + +# Example +```julia +results = computeMacroMetrics(precsByTf, recsByTf, fprsByTf, auprsByTf, aurocsByTf) +results[:macroPR][:auprInterpolated] # macro AUPR +results[:macroPR][:precisions] # macro precision curve +results[:macroROC][:tprs] # macro TPR curve +``` +""" +function computeMacroMetrics(gsPrecisionsByTf, gsRecallsByTf, gsFprsByTf, gsAuprsByTf, gsArocsByTf; + target_points=1000, + min_step=1e-4, + step_method=:min_gap) + + # --- Scalar Macro Metrics --- + macroAUPR_scalar = mean(filter(isfinite, gsAuprsByTf)) + macroAUROC_scalar = mean(filter(isfinite, gsArocsByTf)) + + + # --- Interpolated Macro Curves --- + + # Determine dynamic max recall/FPR + allRecalls = reduce(vcat, filter(!isempty, gsRecallsByTf)) + allRecalls = filter(x -> isfinite(x), allRecalls) + allFprs = reduce(vcat, filter(!isempty, gsFprsByTf)) + allFprs = filter(x -> isfinite(x), allFprs) + + maxRecall = isempty(allRecalls) ? 0.0 : maximum(allRecalls) + maxFpr = isempty(allFprs) ? 
0.0 : maximum(allFprs) + + recStep = adaptiveStep(gsRecallsByTf; target_points=target_points, min_step=min_step, method=step_method) + fprStep = adaptiveStep(gsFprsByTf; target_points=target_points, min_step=min_step, method=step_method) + recInterpPts = 0.0:recStep:maxRecall # Interpolation points + fprInterpPts = 0.0:fprStep:maxFpr + + macroPrecisions = zeros(length(recInterpPts)) + macroTprs = zeros(length(fprInterpPts)) + nTFs = length(gsPrecisionsByTf) + validTFs = [i for i in 1:nTFs if !isempty(gsRecallsByTf[i]) && maximum(gsRecallsByTf[i]) > 0] + for i in validTFs + # PR Curve + interpPrec = dedupNInterpolate(gsRecallsByTf[i], gsPrecisionsByTf[i], recInterpPts) + macroPrecisions .+= interpPrec + # ROC Curve # This is needed for ROC curves. Here Tprs != Recall but interpolated Recall on Fprs + interpTpr = dedupNInterpolate(gsFprsByTf[i], gsRecallsByTf[i], fprInterpPts) + macroTprs .+= interpTpr + end + + macroPrecisions ./= length(validTFs) + macroTprs ./= length(validTFs) + macroAUPR_interpolated = sum(0.5 .* (macroPrecisions[2:end] + macroPrecisions[1:end-1]) .* diff(recInterpPts)) + macroAUROC_interpolated = sum(0.5 .* (macroTprs[2:end] + macroTprs[1:end-1]) .* diff(fprInterpPts)) + + macroResults = OrderedDict( + :macroPR => OrderedDict( + # :auprScalar => macroAUPR_scalar, + :auprInterpolated => macroAUPR_interpolated, + :precisions => macroPrecisions, + :recalls => collect(recInterpPts) + ), + :macroROC => OrderedDict( + # :aurocScalar => macroAUROC_scalar, + :aurocInterpolated => macroAUROC_interpolated, + :fprs => collect(fprInterpPts), + :tprs => macroTprs + ) +) + return macroResults +end + +""" + evaluatePerTF(regs, targs, rankings, infEdges, gsEdgesByTf, uniGsRegs, gsRandPRbyTf; + saveDir=nothing, breakTies=true, target_points=1000, min_step=1e-4, + step_method=:min_gap, xLimitRecall=1.0) + +Compute per-TF precision-recall and ROC metrics for an inferred gene regulatory network (GRN) +against a gold standard, and optionally plot per-TF PR/ROC 
curves.

# Arguments
- `regs::Vector{<:AbstractString}` : Regulators for each inferred edge.
- `targs::Vector{<:AbstractString}` : Target genes for each inferred edge.
- `rankings::Vector{Float64}` : Confidence scores for inferred edges.
- `infEdges::Vector{String}` : Inferred edges as `"TF,Target"` strings.
- `gsEdgesByTf::Vector{Array{String}}` : Gold standard edges grouped by TF.
- `uniGsRegs::Vector{<:AbstractString}` : Unique TFs in the gold standard.
- `gsRandPRbyTf::Vector{Float64}` : Random baseline AUPR per TF.

# Keyword Arguments
- `saveDir::Union{String,Nothing}=nothing` : Directory to save per-TF plots. If `nothing`, plots are skipped.
- `breakTies::Bool=true` : Average indicators over tied rankings to smooth PR curves.
- `target_points::Int=1000` : Interpolation points for macro curves.
- `min_step::Float64=1e-4` : Minimum step size for interpolation.
- `step_method::Symbol=:min_gap` : Step selection method for macro metrics.
- `xLimitRecall::Float64=1.0` : Maximum recall shown on the x-axis of diagnostic per-TF PR plots.
+ +# Returns +An `OrderedDict` with: +- `:perTF` → `OrderedDict` with `:gsRegs`, `:tfEdges`, `:tfEdgesVec`, `:precisions`, + `:recalls`, `:fprs`, `:randPR`, `:auprs`, `:arocs` +- `:macroPR` → `OrderedDict` with `:auprInterpolated`, `:precisions`, `:recalls` +- `:macroROC` → `OrderedDict` with `:aurocInterpolated`, `:fprs`, `:tprs` + +# Example +```julia +results = evaluatePerTF(regs, targs, rankings, infEdges, totTargGenes, + gsEdgesByTf, uniGsRegs, gsRandPRbyTf) +results[:perTF][:auprs] # per-TF AUPR values +results[:macroPR][:auprInterpolated] # macro AUPR +``` +""" + +function evaluatePerTF( + regs::Vector{<:AbstractString}, targs::Vector{<:AbstractString}, rankings::Vector{Float64}, infEdges::Vector{String}, + gsEdgesByTf::Vector{Array{String}}, uniGsRegs::Vector{<:AbstractString}, gsRandPRbyTf::Vector{Float64}; + saveDir::Union{AbstractString, Nothing} = nothing, breakTies::Bool=true, target_points=1000, min_step=1e-4, step_method=:min_gap, xLimitRecall::Float64 = 1.0) + + + if saveDir !== nothing && !isempty(saveDir) + saveDir = joinpath(saveDir, "perTF") + mkpath(saveDir) + end + + totGsRegs = length(uniGsRegs) + totTargGenes = maximum(length.(gsEdgesByTf)) # approximate max targets + + println("---- Computing Per-TF Metrics") + gsAuprsByTf = zeros(totGsRegs,1) + gsArocsByTf = zeros(totGsRegs,1) + gsPrecisionsByTf = Array{Float64}[] + gsRecallsByTf = Array{Float64}[] + gsFprsByTf = Array{Float64}[] + + allTfEdgesVec = Array{Float64}[] # one vector per TF + allTfEdges = Vector{Vector{String}}() # raw edges + allTfRanks = Array{Float64}[] # one array per TF + for (iTF, tf) in enumerate(uniGsRegs) + # Filter inferred edges for this TF + inds = findall(x -> x == tf, regs) + tfEdges = infEdges[inds] + tfRanks = rankings[inds] + tfEdgesVec = [in(edge, gsEdgesByTf[iTF]) ? 
1 : 0 for edge in tfEdges] + # Tie-break if needed + if breakTies + tfAbsRanks = abs.(tfRanks) + uniqueRanks = unique(tfAbsRanks) + meanIndicator = Dict{Float64, Float64}() + for r in uniqueRanks + idxs = findall(x -> x == r, tfAbsRanks) + meanIndicator[r] = mean(tfEdgesVec[idxs]) + end + tfEdgesVec = [meanIndicator[r] for r in tfAbsRanks] + end + # Save for later + push!(allTfEdgesVec, tfEdgesVec) + push!(allTfEdges, tfEdges) + push!(allTfRanks, tfRanks) # store adjusted ranks for this TF + + # Compute cumulative metrics + tfTP = 0.0 + tfPrec = Float64[] + tfRec = Float64[] + tfFpr = Float64[] + totNegativesTF = totTargGenes - length(gsEdgesByTf[iTF]) + for j in 1:length(tfEdgesVec) + tfTP += tfEdgesVec[j] + fp = j - tfTP + push!(tfPrec, tfTP / j) + push!(tfRec, tfTP / length(gsEdgesByTf[iTF])) + push!(tfFpr, fp / totNegativesTF) + end + # Prepend starting point + if !isempty(tfPrec) + tfRec = vcat(0.0, tfRec) + tfPrec = vcat(tfPrec[1], tfPrec) + tfFpr = vcat(0.0, tfFpr) + end + + # -- Plot perTF metric + currRandPR = gsRandPRbyTf[iTF] + plotPRCurve(tfRec, tfPrec, currRandPR, saveDir; xLimitRecall=xLimitRecall, baseName = tf) + plotROCCurve(tfFpr, tfRec, saveDir; baseName = tf) + + # Save per-TF metrics + push!(gsPrecisionsByTf, tfPrec) + push!(gsRecallsByTf, tfRec) + push!(gsFprsByTf, tfFpr) + # AUPR / AUROC per TF + heightsTF = (tfPrec[2:end] + tfPrec[1:end-1]) / 2 + widthsTF = tfRec[2:end] - tfRec[1:end-1] + gsAuprsByTf[iTF] = sum(heightsTF .* widthsTF) + + widthsRocTF = tfFpr[2:end] - tfFpr[1:end-1] + heightsRocTF = (tfRec[2:end] + tfRec[1:end-1]) / 2 + gsArocsByTf[iTF] = sum(widthsRocTF .* heightsRocTF) + end + + # Compute macro metrics + + # maxLen = maximum(length.(gsPrecisionsByTf)) + # macroPrecisions = zeros(maxLen) + # macroRecalls = zeros(maxLen) + # for i in 1:maxLen + # precVals = Float64[] + # recVals = Float64[] + # for (p, r) in zip(gsPrecisionsByTf, gsRecallsByTf) + # if i <= length(p) + # push!(precVals, p[i]) + # push!(recVals, r[i]) + # end + 
# end + # macroPrecisions[i] = mean(precVals) + # macroRecalls[i] = mean(recVals) + # end + # macroAUPR = mean(gsAuprsByTf) + # macroAUROC = mean(gsArocsByTf) + + # A global result but requires perTF to compute + macroResults = computeMacroMetrics(gsPrecisionsByTf, gsRecallsByTf, gsFprsByTf, gsAuprsByTf, gsArocsByTf; + target_points=target_points, min_step=min_step, step_method=step_method) + + perTFDict = OrderedDict( + :gsRegs => uniGsRegs, + :tfEdges => allTfEdges, # raw inferred edges per TF + :tfEdgesVec => allTfEdgesVec, # 0/1 indicator per TF + :precisions => gsPrecisionsByTf, + :recalls => gsRecallsByTf, + :fprs => gsFprsByTf, + :randPR => gsRandPRbyTf, + :auprs => gsAuprsByTf, + :arocs => gsArocsByTf + ) + + # merge!(perTF, macroResults) + results = OrderedDict( + :perTF => perTFDict, + :macroPR => macroResults[:macroPR], + :macroROC => macroResults[:macroROC] + ) + return results +end + +""" + computePR( + gsFile::String, infTrnFile::String; + gsRegsFile::Union{String, Nothing} = nothing, targGeneFile::Union{String, Nothing} = nothing, # filtering + breakTies::Bool = true, partialAUPRlimit::Float64 = 0.1, doPerTF::Bool = true, # computation + xLimitRecall::Float64 = 1.0, saveDir::Union{String, Nothing} = nothing, # plotting + target_points::Int = 1000, min_step::Float64 = 1e-4, step_method::Symbol = :min_gap) # interpolation + +Compute precision-recall (PR) and ROC metrics for an inferred gene regulatory network (GRN) +against a gold standard, including optional per-TF and macro-level metrics. +Results are saved to a `.jld` file and PR/ROC curves are plotted automatically. + +# Arguments +- `gsFile::String` : Gold standard TSV file with columns in order: TF (1), Target (2), Weight (3) — column names can be anything +- `infTrnFile::String` : Inferred GRN TSV file with columns in order: TF (1), Target (2), Weight (3) — column names can be anything + +# Keyword Arguments +- `gsRegsFile::Union{String,Nothing}=nothing` : File listing regulators to include. 
Default uses all GS regulators. +- `targGeneFile::Union{String,Nothing}=nothing` : File listing target genes to include. Default uses all GS targets. +- `rankColTrn::Int=3` : Column index for scores in the inferred GRN file. +- `breakTies::Bool=true` : Average indicators over tied rankings to smooth PR curves. +- `partialAUPRlimit::Float64=0.1` : Maximum recall cutoff for computing partial AUPR. +- `xLimitRecall::Float64=1.0` : Maximum recall shown on the x-axis of diagnostic PR plots. +- `doPerTF::Bool=true` : Whether to compute per-TF and macro metrics. +- `saveDir::Union{String,Nothing}=nothing` : Directory for saving plots and results. Defaults to a timestamped folder. +- `target_points::Int=1000` : Interpolation points for macro curves. +- `min_step::Float64=1e-4` : Minimum step size for interpolation. +- `step_method::Symbol=:min_gap` : Step selection method for macro metrics. + +# Returns +An `OrderedDict` with: +- `:gsRegs`, `:gsTargs`, `:gsEdges` → Gold standard regulators, targets, and edges. +- `:randPR` → Random baseline PR value. +- `:infRegs`, `:infEdges` → Inferred GRN regulators and edges. +- `:edgesVec` → Tie-adjusted or binary indicator vector. +- `:precisions`, `:recalls`, `:fprs` → Overall PR/ROC curve arrays. +- `:auprs` → Dict with `:full` AUPR and `:partial` AUPR (up to `partialAUPRlimit`). +- `:arocs` → Overall AUROC. +- `:f1scores` → F1-score array. +- `:perTF` → Per-TF metrics dict (see `evaluatePerTF`). `nothing` if `doPerTF=false`. +- `:macroPR` → Macro PR dict. `nothing` if `doPerTF=false`. +- `:macroROC` → Macro ROC dict. `nothing` if `doPerTF=false`. +- `:savedFile` → Path to saved `.jld` results file. 
# Example
```julia
results = computePR("gs.tsv", "inferredGRN.tsv";
    breakTies=true, partialAUPRlimit=0.1, saveDir="plots")

results[:auprs][:full] # overall AUPR
results[:auprs][:partial][:value] # partial AUPR at recall limit
results[:perTF][:auprs] # per-TF AUPRs
results[:macroPR][:auprInterpolated] # macro AUPR
```
"""
# function computePR(
#     gsFile::String, infTrnFile::String;
#     gsRegsFile::Union{String, Nothing} = nothing, targGeneFile::Union{String, Nothing} = nothing,
#     # rankColTrn::Int = 3,
#     breakTies::Bool = true, partialLimitRecall::Float64 = 0.1, doPerTF::Bool = true,
#     saveDir::Union{String, Nothing} = nothing, target_points::Int = 1000, min_step::Float64 = 1e-4,
#     step_method::Symbol = :min_gap)

# FIX(review): the four assignments below were leftover interactive-debugging
# globals with hard-coded absolute paths, executed at include time. They
# defined module-level globals and leaked machine-specific paths into the
# committed source; commented out rather than deleted so the sample inputs
# used during development remain on record.
# infTrnFile = "/data/miraldiNB/anthony/Inferelator_Julia/outputs/251125_HAE_ISGF3_GAS_combined/combined/combined_sp.tsv"
# gsFile = "/data/miraldiNB/giulia/GS_10/IFNB_A549.tsv"
# targGeneFile = "/data/miraldiNB/giulia/intersection/IFNB_A549_target_genes.txt"
# gsRegsFile = "/data/miraldiNB/Katko/Projects/Julia/AnthonyData/tfs.txt"

function computePR(
    gsFile::String, infTrnFile::String;
    gsRegsFile::Union{String, Nothing} = nothing, targGeneFile::Union{String, Nothing} = nothing, # filtering
    breakTies::Bool = true, partialAUPRlimit::Float64 = 0.1, doPerTF::Bool = true, # computation
    xLimitRecall::Float64 = 1.0, saveDir::Union{String, Nothing} = nothing, # plotting
    target_points::Int = 1000, min_step::Float64 = 1e-4, step_method::Symbol = :min_gap) # interpolation

    # Helper returning a fully-populated fallback result dict for degenerate
    # inputs (no overlapping GS edges / no surviving inferred edges), so that
    # callers can index the same keys regardless of outcome.
    function emptyPRResult(;
        gsRegs::Vector{String} = String[],
        gsTargs::Vector{String} = String[],
        gsEdges::Vector{String} = String[],
        infRegs::Vector{String} = String[],
        infEdges::Vector{String} = String[])
        return OrderedDict(
            :gsRegs => gsRegs,
            :gsTargs => gsTargs,
            :gsEdges => gsEdges,
            :infRegs => infRegs,
            :infEdges => infEdges,
            :precisions => Float64[],
            :recalls => Float64[],
            :fprs =>
Float64[], + :auprs => Dict( + :full => 0.0, + :partial => Dict( + :value => 0.0, + :recallLimit => partialAUPRlimit # captured from outer scope + ) + ), + :arocs => 0.0, + :f1scores => Float64[], + :perTF => nothing, + :macroPR => nothing, + :macroROC => nothing, + :savedFile => nothing + ) + end + # ----- Part 1. Load and Validate Input Files + gsData = CSV.File(gsFile; delim="\t", header=true) + trnData = CSV.File(infTrnFile; delim="\t", header=true) + + validateColumnCount(gsData, gsFile) + validateColumnCount(trnData, infTrnFile) + + # Reload selecting only first 3 columns by position + gsData = CSV.File(gsFile; delim="\t", select=[1, 2, 3]) + trnData = CSV.File(infTrnFile; delim="\t", select=[1, 2, 3]) + + + # Initialize defaults BEFORE any conditionals + potTargGenes = String[] + potTargGenesSet = Set{String}() + gsPotRegs = String[] + gsPotRegsSet = Set{String}() + + # ----- Part 2. Load Optional Filter Files (Potential Targets and Regulators) + if targGeneFile !== nothing && !isempty(targGeneFile) + potTargGenes = readlines(targGeneFile) + potTargGenesSet = Set(potTargGenes) + totTargGenes = length(unique(potTargGenes)) + @info "Target gene file loaded: $(length(potTargGenes)) genes from $targGeneFile" + else + @info "No target gene file provided — will use gold standard targets after GS is loaded" + end + + # Load regulator list if provided + if gsRegsFile !== nothing && !isempty(gsRegsFile) + gsPotRegs = readlines(gsRegsFile) + gsPotRegsSet = Set(gsPotRegs) + @info "Regulator file loaded: $(length(gsPotRegs)) regulators from $gsRegsFile" + else + @info "No regulator file provided — will use all regulators in gold standard" + end + + # ----- Part 3. 
Filter and Process Gold Standard + + # Filter gold standard by regulators and target genes if specified: limit to TF-gene interactions considered by the model + gsFilteredData = filter(row -> row[3] > 0, gsData) + # if gsRegsFile !== nothing && !isempty(gsRegsFile) + # # gsFilteredData = filter(row -> (row.TF in gsPotRegs) && (row.Target in potTargGenes), gsFilteredData) + # gsFilteredData = filter(row ->(row[1] in gsPotRegsSet) && (row[2] in potTargGenesSet), gsFilteredData) + # size(gsFilteredData) + # end + + # Filter by regulators if provided + if !isempty(gsPotRegsSet) + gsFilteredData = filter(row -> row[1] in gsPotRegsSet, gsFilteredData) + end + + # Filter by targets if provided + if !isempty(potTargGenesSet) + gsFilteredData = filter(row -> row[2] in potTargGenesSet, gsFilteredData) + end + + # Edge-case check: skip if no gold standard edges + if isempty(gsFilteredData) + @warn "No overlapping gold standard edges after filtering. Skipping PR/ROC evaluation." + return emptyPRResult() + end + + # Define gold standard edges (each as "TF,Target") + gsRegs = collect(row[1] for row in gsFilteredData) + gsTargs = collect(row[2] for row in gsFilteredData) + totGsInts = size(gsFilteredData)[1] + uniGsRegs = unique(gsRegs) # unique regulators or TFs in GS + totGsRegs = length(uniGsRegs) + uniGsTargs = unique(gsTargs) # unique targets in GS + totGsTargs = length(uniGsTargs) + # Create new edges vector. (Each edge is a tuple (TF, gene)) + gsEdges = [string(gsRegs[i], ",", gsTargs[i]) for i in 1:length(gsRegs)] + + # If targGeneFile wasn’t provided, use GS targets. 
+ if targGeneFile === nothing || isempty(targGeneFile) + potTargGenes = uniGsTargs + totTargGenes = length(potTargGenes) + end + + # Compute evaluation universe and random PR baseline + gsTotPotInts = totTargGenes*totGsRegs # complete universe size + # gsTotPotInts = totGsRegs * totGsTargs + gsRandPR = totGsInts/gsTotPotInts + + # group gold standard edges by regulator + gsEdgesByTf = Array{String}[] + gsRandPRbyTf = zeros(totGsRegs) + for gind = 1:totGsRegs + currInds = findall(x -> x == uniGsRegs[gind], gsRegs) + push!(gsEdgesByTf, vec(permutedims(gsEdges[currInds]))) + gsRandPRbyTf[gind] = length(gsEdgesByTf[gind])/totGsTargs + end + + # ----- Part 4. Load and Filter Inferred GRN + @info "Loading and Processing Inferred GRN" + # Filter "inferred GRN" to include only 'TFs' in GS and 'TargetGenes' in 'potTargGenes' + grnData = filter(row -> (row[1] in uniGsRegs) && (row[2] in potTargGenes), trnData) + grnData = collect(grnData) + + if isempty(grnData) + @warn "No inferred edges after filtering. Skipping evaluation." + return emptyPRResult(gsRegs=uniGsRegs, gsTargs=uniGsTargs, gsEdges=gsEdges) + end + + # Order by absolute value of weights/confidences + grnData = sort(grnData, by = row -> abs(row[3]), rev = true) + regs = collect(row[1] for row in grnData) + targs = collect(row[2] for row in grnData) + rankings = collect(row[3] for row in grnData) + absRankings = abs.(rankings) + # Create inferred edges vector. (Each edge is a tuple (TF, gene)) + infEdges = [string(regs[i], ",", targs[i]) for i in 1:length(regs)] + totTrnInts = length(infEdges) + + # ----- Part 5. Compute Edge Indicators + @info "Computing Edge Indicators and Labels" + # create binary labels. 1 if infEgde in gsEdge and 0 otherwise + commonEdgesBinaryVec = [in(edge, gsEdges) ? 1 : 0 for edge in infEdges] + @info "Total Interactions w/ GS TFs ($(length(unique(regs)))): $totTrnInts" + + if breakTies # If breaking ties is desired, compute tie-adjusted (mean) vector. 
+ # Break ties in weights/confidences to smooth out abrupt jumps caused by abitrary ordering of tied predictions + # Create a dictionary mapping each score to the mean value of commonEdgesVec for tied predictions + @info "Tie breaker enabled" + uniqueRankings = unique(absRankings) + meanIndicator = Dict{Float64, Float64}() + for currRank in uniqueRankings + inds = findall(x -> x == currRank, absRankings) + meanIndicator[currRank] = mean(commonEdgesBinaryVec[inds]) + end + meanEdgesVec = [meanIndicator[ix] for ix in absRankings] + edgesVec = meanEdgesVec + else # If not breaking ties, use binary vector + edgesVec = commonEdgesBinaryVec + end + + # ----- Part 6. Compute Overall PR/ROC Performance Metrics + @info "Computing Performance Metrics" + totalNegatives = gsTotPotInts - totGsInts # Total posisble interactions - length(gsEdges) = TN + FP + gsPrecisions = zeros(totTrnInts) + gsRecalls = zeros(totTrnInts) + gsFprs = zeros(totTrnInts) + + cummulativeTP = 0.0 # can be fractional in tie-adjusted mode + for idx in 1:totTrnInts + cummulativeTP += edgesVec[idx] # Add the effective contirbution for this prediction. This is also True Positive + # False positive (FP). This is a weighted FP in the case of tie breaking + falsePositives = idx - cummulativeTP + + # Compute precision: TP divided by total prediction so far + # idx is always the total predicted positive at any point j + # such that j = TP + FP + gsPrecisions[idx] = cummulativeTP / idx # + gsRecalls[idx] = cummulativeTP / totGsInts # truePositives/length(gsEdges) == tp/tp+fn + gsFprs[idx] = falsePositives / totalNegatives + end + + # Prepend starting point for plotting. + if !isempty(gsPrecisions) + gsRecalls = vcat(0.0, gsRecalls) + gsPrecisions = vcat(gsPrecisions[1], gsPrecisions) + gsFprs = vcat(0.0, gsFprs) + end + # Compute F1-scores + # gsF1scores = 2 * (gsPrecisions .* gsRecalls) ./ (gsPrecisions + gsRecalls); + gsF1scores = [p + r > 0 ? 
2p*r/(p+r) : 0.0 for (p, r) in zip(gsPrecisions, gsRecalls)]; + + # ----- Part 7. Compute AUPR and AUROC + @info "Computing AUPR and AUROC" + # Here, AUPR is computed using trapezoidal rule. Other methods available is a step-function approximation + heights = (gsPrecisions[2:end] + gsPrecisions[1:end-1])/2 + widths = gsRecalls[2:end] - gsRecalls[1:end-1] + gsAuprs = sum(heights .* widths) + #= + # Step-function approximation. This works but is less robust + gsAuprs = 0.0 + prev_recall = 0.0 + for (r, p) in zip(gsRecalls, gsPrecisions) + delta = r - prev_recall + gsAuprs += delta * p + prev_recall = r + end + =# + + # PARTIAL AUPR @ partialAUPRlimit + indx = findall(r -> r <= partialAUPRlimit, gsRecalls) # Find indices of recalls <= partialUpperLimRecall + recalls_sub = gsRecalls[indx] + precisions_sub = gsPrecisions[indx] + heights_sub = (precisions_sub[2:end] + precisions_sub[1:end-1]) ./ 2 + widths_sub = recalls_sub[2:end] - recalls_sub[1:end-1] + partialAUPR = sum(heights_sub .* widths_sub) + + # AROC : Trapezoidal Rule + widthsRoc = gsFprs[2:end] - gsFprs[1:end-1] # Change in FPR (the x-axis) between successive points. + heightsRoc = (gsRecalls[2:end] + gsRecalls[1:end-1]) / 2 # Average TPR (recall) for each segment. + gsArocs = sum(widthsRoc .* heightsRoc) + #= + gsArocs = 0.0 + prev_fpr = 0.0 + for (f, r) in zip(fprs, recalls) + delta = f - prev_fpr + gsArocs += delta * r + prev_fpr = f + end + =# + + # ----- Part 8. Save Directory + baseName = splitext(basename(infTrnFile))[1] + if saveDir === nothing || isempty(saveDir) + dateStr = Dates.format(now(), "yyyymmdd_HHMMSS") + saveDir = joinpath(pwd(), baseName * "_" * dateStr) + end + mkpath(saveDir) + + # ----- Part 9. 
Per-TF and Macro Metrics + resultsTF = nothing + perTFDict = nothing + macroPRDict = nothing + macroROCDict = nothing + if doPerTF + if !isempty(regs) && !isempty(targs) && !isempty(gsEdgesByTf) + resultsTF = evaluatePerTF( + regs, targs, rankings, infEdges, gsEdgesByTf, uniGsRegs, gsRandPRbyTf; + saveDir, + breakTies=breakTies, target_points=target_points, min_step=min_step, + step_method=step_method, xLimitRecall=xLimitRecall + ) + # Extract Per-TF Results + perTFDict = resultsTF[:perTF] + macroPRDict = resultsTF[:macroPR] + macroROCDict = resultsTF[:macroROC] + else + @warn "Skipping per-TF evaluation: no valid edges after filtering" + end + end + + + # ----- Part 10. Plot and Save Results + suffix = breakTies ? "_tiesBroken" : "" + baseName = baseName * suffix + plotPRCurve(gsRecalls, gsPrecisions, gsRandPR, saveDir; xLimitRecall=xLimitRecall, baseName = baseName) + plotROCCurve(gsFprs, gsRecalls, saveDir; baseName = baseName) + + + results = OrderedDict( + :gsRegs => uniGsRegs, + :gsTargs => uniGsTargs, + :gsEdges => gsEdges, + :randPR => gsRandPR, + :infRegs => regs, + :infEdges => infEdges, + :breakTies => breakTies, + :stepVals => rankings, + :rankings => rankings, + # :commonEdgesBinaryVec => commonEdgesBinaryVec, + # :meanEdgesVec => breakTies ? 
meanEdgesVec : nothing, + :edgesVec => edgesVec, + :precisions => gsPrecisions, + :recalls => gsRecalls, + :fprs => gsFprs, + # :auprs => gsAuprs, + :auprs => Dict( + :full => gsAuprs, + :partial => Dict( + :value => partialAUPR, + :recallLimit => partialAUPRlimit + ) + ), + :arocs => gsArocs, + :f1scores => gsF1scores, + # --- Include results from evaluatePerTF --- + :perTF => perTFDict, + :macroPR => macroPRDict, + :macroROC => macroROCDict + ) + + # Define a standard filename for saving the performance metrics + savedFile = joinpath(saveDir, baseName * "_PerformanceMetric.jld") + @save savedFile results + + # Add the saved file path to the results, so the caller immediately knows it + results[:savedFile] = savedFile + + return results +end \ No newline at end of file diff --git a/src/metrics/Constants.jl b/src/metrics/Constants.jl new file mode 100755 index 0000000..8314a21 --- /dev/null +++ b/src/metrics/Constants.jl @@ -0,0 +1,9 @@ +# const GS_REQUIRED_COLS = [:TF, :Target, :Weight] +# const TRN_REQUIRED_COLS = [:TF, :Target, :Score] +const MIN_REQUIRED_COLS = 3 + +const DEFAULT_COLORS = [ + "#377eb8", "#ff7f00", "#4daf4a", "#e41a1c", "#984ea3", "#a65628", + "#f781bf", "#00ced1", "#000000", "#5A9D5A", "#D96D3B", "#FFAD12", + "#66628D", "#91569A", "#B6742A", "#DD87B4", "#D26D7A", "#DAA520", "#dede00" +] \ No newline at end of file diff --git a/src/metrics/MetricUtils.jl b/src/metrics/MetricUtils.jl new file mode 100755 index 0000000..cf53659 --- /dev/null +++ b/src/metrics/MetricUtils.jl @@ -0,0 +1,96 @@ +""" + adaptiveStep(xs; target_points=1000, min_step=1e-4, method=:min_gap) + +Select a step size for interpolation based on actual values in `xs`. + +# Arguments +- `xs`: Array of arrays (per TF) or single array of values. +- `target_points`: Desired number of points if `method=:target_points`. +- `min_step`: Minimum allowable step size. +- `method`: `:min_gap` uses smallest nonzero gap; `:target_points` uses range/target_points. 
"""
    adaptiveStep(xs; target_points=1000, min_step=1e-4, method=:min_gap)

Select a step size for interpolation based on actual values in `xs`.

# Arguments
- `xs`: Array of arrays (per TF) or a single array of values.
- `target_points`: Desired number of points if `method=:target_points`.
- `min_step`: Minimum allowable step size.
- `method`: `:min_gap` uses the smallest nonzero gap between unique values;
  `:target_points` uses `range/target_points`.

# Returns
A step size, never smaller than `min_step`.

# Throws
- `ArgumentError` if `xs` is empty (previously an opaque `BoundsError` from `xs[1]`).
- `ErrorException` for an unknown `method`.
"""
function adaptiveStep(xs; target_points=1000, min_step=1e-4, method=:min_gap)
    isempty(xs) && throw(ArgumentError("adaptiveStep: `xs` must be non-empty"))

    # Flatten nested per-TF arrays into one sorted vector of unique values.
    vals = isa(first(xs), AbstractArray) ? reduce(vcat, xs) : xs
    vals = sort(unique(vals))

    if method == :min_gap
        # Smallest strictly-positive gap between consecutive unique values;
        # fall back to min_step when all values coincide.
        positive_gaps = filter(>(0), diff(vals))
        return isempty(positive_gaps) ? min_step : max(minimum(positive_gaps), min_step)
    elseif method == :target_points
        # Spread roughly `target_points` points evenly over the observed range.
        range_span = maximum(vals) - minimum(vals)
        return max(range_span / target_points, min_step)
    else
        error("Unknown method: $method")
    end
end


"""
    dedupNInterpolate(xs, ys, interpPts)

Linearly interpolate `ys` onto `interpPts`, handling duplicates and invalid values.

Filters `NaN`/`Inf` pairs, sorts by `xs`, collapses duplicate `xs` by keeping the
maximum `y`, then interpolates with `Flat()` extrapolation (Interpolations.jl).
Returns a zero vector when fewer than 2 valid points remain.
"""
function dedupNInterpolate(xs::Vector{Float64}, ys::Vector{Float64}, interpPts::AbstractVector{Float64})
    # Keep only pairs where both coordinates are finite.
    valid_inds = isfinite.(xs) .& isfinite.(ys)
    xs = xs[valid_inds]
    ys = ys[valid_inds]

    # Interpolation needs at least two support points.
    if length(xs) < 2
        return zeros(length(interpPts))
    end

    # Sort xs and carry ys along.
    perm = sortperm(xs)
    xsSorted = xs[perm]
    ysSorted = ys[perm]

    # Collapse duplicate x values, keeping the maximum y for each.
    dict = Dict{Float64,Float64}()
    for (x, y) in zip(xsSorted, ysSorted)
        dict[x] = haskey(dict, x) ? max(dict[x], y) : y
    end

    xsUnique = sort(collect(keys(dict)))
    ysUnique = [dict[x] for x in xsUnique]

    # Flat() clamps queries outside [first(xsUnique), last(xsUnique)] to the
    # boundary values; constant ys are handled fine by LinearInterpolation.
    itp = LinearInterpolation(xsUnique, ysUnique, extrapolation_bc=Flat())
    return itp.(interpPts)
end
"""
    validateColumnCount(data, filename)

Ensure `data` (any object supporting `propertynames`, e.g. a `DataFrame`) has at
least `MIN_REQUIRED_COLS` columns, which are expected in order: TF, Target, Score.
Throws an informative error naming `filename` otherwise; returns `nothing`.
"""
function validateColumnCount(data, filename)
    ncols = length(propertynames(data))
    if ncols < MIN_REQUIRED_COLS
        # Interpolate the constant so the message can never drift from the check.
        error("""
        $filename must have at least $(MIN_REQUIRED_COLS) columns in order: TF, Target, Score
        Found: $ncols column(s)
        """)
    end
    return nothing
end

using JLD2, InlineStrings
using PyPlot
using Colors
using Dates
using DataFrames
using OrderedCollections
using Measures
using CSV
using Base.Threads
using Interpolations
using Statistics

# ----------------------------
# Plot defaults
# ----------------------------

# Shared raster resolution for all figure output.
const dpi = 600

"""
    setPlotDefaults!()

Apply project-wide matplotlib defaults (font family, label/tick/legend sizes,
and figure/savefig DPI). Mutates `PyPlot.matplotlib.rcParams`; returns `nothing`.
"""
function setPlotDefaults!()
    rc = PyPlot.matplotlib.rcParams

    rc["font.family"] = "Nimbus Sans"
    rc["axes.titlesize"] = 9
    rc["axes.labelsize"] = 9
    rc["xtick.labelsize"] = 7
    rc["ytick.labelsize"] = 7
    rc["legend.fontsize"] = 9

    rc["figure.dpi"] = dpi
    rc["savefig.dpi"] = dpi

    return nothing
end
"""
    padColors(listFiles)

Return a vector of hex color strings with at least `length(listFiles)` entries,
extending the predefined `DEFAULT_COLORS` palette with random colors as needed.

# Arguments
- `listFiles`: Collection of items needing unique colors (only its length is used).
"""
function padColors(listFiles)
    lineColors = copy(DEFAULT_COLORS)
    deficit = length(listFiles) - length(lineColors)

    if deficit > 0
        @warn "More entries ($(length(listFiles))) than available colors ($(length(lineColors))); generating random colors for the excess"
        extras = String[]
        # Sample until we have enough colors not already in use; the original
        # fixed-size sample could come up short after deduplication.
        while length(extras) < deficit
            # Colors.hex returns "RRGGBB"; matplotlib requires the "#" prefix.
            candidate = "#" * hex(RGB(rand(), rand(), rand()))
            if !(candidate in lineColors) && !(candidate in extras)
                push!(extras, candidate)
            end
        end
        append!(lineColors, extras)
    end
    return lineColors
end

# ----------------------------------------------------------------------
# Helper: Data Loading
# ----------------------------------------------------------------------

"""
    loadPRData(source; mode=:global)

Load precision-recall data from `source`, which is either a path to a saved
results JLD file (containing a `"results"` dict) or an in-memory results dict.

# Keywords
- `mode`: `:global` → overall PR curve; `:macro` → macro-averaged PR (falls back
  to global if absent); `:perTF` → the per-TF results dict.

# Returns
A `Dict` with `:precisions`, `:recalls`, `:randPR` (or the per-TF dict for
`mode=:perTF`), or `nothing` when the data cannot be loaded.
"""
function loadPRData(source; mode::Symbol=:global)
    # Resolve a file path to its stored results dict; pass dicts through.
    r = if isa(source, String)
        try
            load(source)["results"]
        catch e
            @warn "Could not load $source" exception=(e, catch_backtrace())
            return nothing
        end
    else
        source
    end

    if r === nothing
        return nothing
    end

    if mode == :macro
        if !haskey(r, :macroPR)
            @warn "No macroPR found, falling back to global"
            r_macro = r
        else
            r_macro = r[:macroPR]
        end
        return Dict(
            :precisions => r_macro[:precisions],
            :recalls => r_macro[:recalls],
            # Fall back to the global random-PR baseline when macro has none.
            :randPR => get(r_macro, :randPR, get(r, :randPR, nothing))
        )
    elseif mode == :global
        return Dict(
            :precisions => r[:precisions],
            :recalls => r[:recalls],
            :randPR => get(r, :randPR, nothing)
        )
    elseif mode == :perTF
        if !haskey(r, :perTF)
            @warn "No perTF data found"
            return nothing
        end
        return r[:perTF]
    else
        error("Unknown mode: $mode")
    end
end

# ----------------------------------------------------------------------
# Transformation Helper
# ----------------------------------------------------------------------

"""
    makeTransform(yScale::String)

Create forward and inverse transformation functions for plot axis scaling.

# Arguments
- `yScale`: `"linear"`, `"sqrt"`, or `"cubert"`. Any other value warns and
  falls back to linear.

# Returns
A `(forward, inverse)` tuple of broadcasting-friendly functions.
"""
function makeTransform(yScale::String)
    if yScale == "linear"
        forwardFunc = x -> x
        inverseFunc = x -> x
    elseif yScale == "sqrt"
        forwardFunc = x -> sqrt.(x)
        inverseFunc = x -> x .^ 2
    elseif yScale == "cubert"
        forwardFunc = x -> x .^ (1 / 3)
        inverseFunc = x -> x .^ 3
    else
        @warn "Unknown scale type: $yScale. Using linear scale."
        forwardFunc = x -> x
        inverseFunc = x -> x
    end
    return forwardFunc, inverseFunc
end
+# function plotFileData!(axes, legendLabel::String, dataSource, color; lineType=nothing) +# """ +# Plots precision-recall curves on specified axes from data file or direct data. + +# # Arguments +# - `axes`: Array of plot axes to draw on +# - `legendLabel::String`: Label for plot legend +# - `dataSource`: Either a file path (String) or a Dict containing PR data +# - `color`: Color for plotting +# - `lineType=nothing`: Optional line style specification + +# # Returns +# Loaded/processed data or nothing if processing fails. +# """ + +# # Initialize variables to hold the data +# precisions = nothing +# recalls = nothing +# randPR = nothing +# fileData = nothing + +# # Process the data source based on its type +# if isa(dataSource, String) +# # It's a file path - try to load it +# try +# fileData = load(dataSource) +# # Extract data from the loaded file +# precisions = fileData["results"][:precisions] +# recalls = fileData["results"][:recalls] +# randPR = get(fileData["results"], :randPR, nothing) +# catch e +# println("Error loading file: $dataSource - $e") +# return nothing +# end +# elseif isa(dataSource, Dict) || isa(dataSource, OrderedDict) +# # It's already a data dictionary +# fileData = dataSource + +# # Check if it's a results dictionary or a direct data dictionary +# if haskey(dataSource, :precisions) && haskey(dataSource, :recalls) +# # Direct data format +# precisions = dataSource[:precisions] +# recalls = dataSource[:recalls] +# randPR = get(dataSource, :randPR, nothing) +# elseif haskey(dataSource, "results") +# # Nested results format (like from a loaded JLD file) +# precisions = dataSource["results"][:precisions] +# recalls = dataSource["results"][:recalls] +# randPR = get(dataSource["results"], :randPR, nothing) +# else +# println("Error: Data dictionary does not contain required precision/recall data") +# return nothing +# end +# else +# println("Error: Unsupported data source type: $(typeof(dataSource))") +# return nothing +# end + +# # Ensure we 
"""
    plotFileData!(axes, source, label, color; lineType=nothing, lineWidth=nothing, mode=:global)

Draw one precision-recall curve (loaded via `loadPRData(source; mode)`) on every
axis in `axes`.

# Returns
The loaded PR data `Dict` on success, or `nothing` when no valid PR data could
be extracted from `source`.
"""
function plotFileData!(axes::AbstractVector, source, label::String, color::String;
        lineType=nothing, lineWidth=nothing, mode::Symbol=:global)
    data = loadPRData(source; mode)
    if isnothing(data)
        @warn "Skipping $label ($source) — no valid PR data found"
        # Return `nothing`, not `false`: callers test `result !== nothing`, and a
        # `false` sentinel slips past that check and then fails on `result[:randPR]`.
        return nothing
    end

    style = lineType === nothing ? "-" : lineType
    width = lineWidth === nothing ? 0.7 : lineWidth
    for ax in axes
        ax.plot(data[:recalls], data[:precisions];
                label=label, color=color, linewidth=width, linestyle=style)
    end
    return data
end


"""
    combineLegends(fig)

Collect legend handles and labels from every axis of `fig`, keeping only the
first occurrence of each label. Returns `(handles, labels)`.
"""
function combineLegends(fig)
    allHandles = Any[]
    allLabels = String[]
    for ax in fig.axes
        handles, labels = ax.get_legend_handles_labels()
        for (h, l) in zip(handles, labels)
            if !(l in allLabels)
                push!(allHandles, h)
                push!(allLabels, l)
            end
        end
    end
    return allHandles, allLabels
end


"""
    styleAxis!(ax; xZoom=0.1, yRange=nothing, xStep=0.05, yStep=0.1,
               yScale="linear", setXticks=true, tickLabelSize=7)

Apply the shared axis styling: x/y limits and ticks, optional nonlinear y
scaling (see `makeTransform`), and a light major/minor grid.
`yRange=nothing` means the full `(0.0, 1.0)` range.
"""
function styleAxis!(ax; xZoom=0.1, yRange::Union{Nothing,Tuple}=nothing,
        xStep=0.05, yStep=0.1, yScale="linear", setXticks=true, tickLabelSize::Int=7)
    # X-axis
    ax.set_xlim(0, xZoom)
    if setXticks
        # Round before ceil so float noise (e.g. 0.1/0.05 == 2.0000000000000004)
        # cannot add a spurious extra tick past xZoom.
        xMax = xStep * ceil(round(xZoom / xStep; digits=8))
        ax.set_xticks(0:xStep:xMax)
    end

    # Y-axis limits and (optionally nonlinear) scale
    forwardFunc, inverseFunc = makeTransform(yScale)
    yMin, yMax = isnothing(yRange) ? (0.0, 1.0) : yRange
    ax.set_ylim(yMin, yMax)

    # Y ticks with fixed two-decimal labels
    yTicks = yMin:yStep:yMax
    ax.set_yticks(yTicks)
    ax.set_yticklabels(string.(round.(yTicks; digits=2)))
    ax.set_yscale("function", functions=(forwardFunc, inverseFunc))

    # Grid & ticks
    ax.grid(true, which="major", linestyle="-", linewidth=0.5, color="lightgray")
    ax.minorticks_on()
    ax.grid(true, which="minor", linestyle=":", linewidth=0.25, color="lightgray")
    ax.tick_params(axis="both", which="both", labelsize=tickLabelSize, direction="out")
end
"""
    plotPRCurves(listFilePR, dirOut, saveName; kwargs...)

Plot and save precision-recall curves for one or more networks against a gold
standard. Supports a single axis or a broken y-axis, and always saves both a
legend and a no-legend version of the figure.

# Arguments
- `listFilePR` : `OrderedDict` mapping legend labels to file paths or to
  `OrderedDict`s with `:precisions`, `:recalls`, and `:randPR` keys.
- `dirOut::String` : Output directory (created if missing; `pwd()` if empty).
- `saveName::String`: Base name for output files (timestamp used if empty).

# Keywords
- `xLimitRecall=0.1` : Maximum recall shown on the x-axis.
- `yZoomPR=[]` : `[]` → full 0–1 y-range; `[y]` → 0–y; `[lo, hi]` → broken y-axis
  with a gap between `lo` and `hi`.
- `xStepSize=nothing`, `yStepSize=nothing` : Tick steps (defaults 0.05 / 0.2).
- `yScale="linear"` : `"linear"`, `"sqrt"`, or `"cubert"`.
- `isInside=true` : Legend inside (`true`) or outside (`false`) the plot.
- `lineColors=[]`, `lineTypes=[]`, `lineWidths=[]` : Per-curve style overrides.
- `heightRatios=nothing` : `[top, bottom]` panel heights for the broken axis;
  `nothing` auto-computes them proportionally from `yZoomPR`.
- `mode=:global` : PR data to use — `:global`, `:macro` (falls back to global),
  or `:perTF` (see `loadPRData`).

# Returns
Path of the saved legend version of the plot, or `nothing` when `yZoomPR`
is invalid and nothing was plotted.
"""
function plotPRCurves(listFilePR, dirOut::String, saveName::String;
        xLimitRecall=0.1, yZoomPR=[],
        xStepSize::Union{Nothing,Real}=nothing,
        yStepSize::Union{Nothing,Real}=nothing,
        yScale::String="linear", isInside::Bool=true,
        lineColors=[],   # empty vector means "use default palette"
        lineTypes=[],    # empty vector means "solid lines"
        lineWidths=[],   # empty vector means "default width"
        heightRatios=nothing,
        mode::Symbol=:global)

    # Defaults
    xStep = isnothing(xStepSize) ? 0.05 : xStepSize
    yStep = isnothing(yStepSize) ? 0.2 : yStepSize
    dirOut = isempty(dirOut) ? pwd() : mkpath(dirOut)

    # Style parameters
    axisTitleSize = 9
    legendSize = 9

    # Color assignment: pad the default palette if the caller gave none.
    if isempty(lineColors)
        lineColors = padColors(listFilePR)
    end

    # Auto height ratios for broken-axis mode. Default is `nothing` so this
    # documented auto-compute branch is actually reachable (a non-nothing
    # default made it dead code before).
    if length(yZoomPR) == 2 && heightRatios === nothing
        lowerSpan = max(yZoomPR[1], 0.0)        # bottom panel: 0 → yZoomPR[1]
        upperSpan = max(1.0 - yZoomPR[2], 0.0)  # top panel: yZoomPR[2] → 1
        lowerSpan = lowerSpan == 0.0 ? 1e-6 : lowerSpan  # protect zero spans
        upperSpan = upperSpan == 0.0 ? 1e-6 : upperSpan
        heightRatios = [upperSpan, lowerSpan]   # gridspec order: [top, bottom]
    end

    # Output path for the no-legend version (timestamped when saveName is empty).
    prSavePath() = joinpath(dirOut,
        string(isempty(saveName) ? Dates.format(now(), "yyyymmdd_HHMMSS") : saveName,
               "_PR_noLegend.pdf"))

    # Draw every curve on the given axes; returns the last loaded PR data
    # (used below for the random-expectation baseline).
    function drawCurves!(axes)
        lastPlotData = nothing
        @info "Plotting PR curves"
        for (idx, (legendLabel, currFilePR)) in enumerate(listFilePR)
            @info "Plot $idx: $legendLabel" file=currFilePR
            currentLineType = (length(lineTypes) >= idx && lineTypes[idx] != "") ? lineTypes[idx] : nothing
            currentLineWidth = (length(lineWidths) >= idx && lineWidths[idx] != "") ? lineWidths[idx] : nothing
            lastPlotData = plotFileData!(axes, currFilePR, legendLabel, lineColors[idx];
                lineType=currentLineType, lineWidth=currentLineWidth, mode)
        end
        return lastPlotData
    end

    # `isa AbstractDict` also guards against a legacy `false` failure sentinel.
    hasRandPR(d) = d isa AbstractDict && d[:randPR] !== nothing

    savePathLegend = nothing

    if isempty(yZoomPR) || length(yZoomPR) == 1
        # ---- Single-axis mode ----
        yZoom = isempty(yZoomPR) ? 1.0 : yZoomPR[1]

        fig, ax = subplots(figsize=(2, 2), layout="constrained")
        styleAxis!(ax; xZoom=xLimitRecall, yRange=(0, yZoom), xStep=xStep, yStep=yStep, yScale=yScale)

        lastPlotData = drawCurves!([ax])

        # Random-expectation baseline from the last successfully loaded file.
        if hasRandPR(lastPlotData)
            ax.axhline(lastPlotData[:randPR], linestyle="-.", linewidth=2,
                       color=[0.6, 0.6, 0.6], label="Random")
        end

        ax.set_xlabel("Recall", fontsize=axisTitleSize, family="Nimbus Sans")
        ax.set_ylabel("Precision", fontsize=axisTitleSize, family="Nimbus Sans")

        savePath = prSavePath()
        PyPlot.savefig(savePath, dpi=600)    # version without legend

        if isInside
            ax.legend(borderaxespad=0.2, frameon=true)
        else
            ax.legend(loc="center", bbox_to_anchor=(1.25, 0.5), borderaxespad=0.2, frameon=false)
            fig.set_size_inches(3.5, 2.5)    # widen only when legend sits outside
        end
        savePathLegend = replace(savePath, "PR_noLegend.pdf" => "PR.pdf")
        PyPlot.savefig(savePathLegend, dpi=600)

    elseif length(yZoomPR) == 2
        # ---- Broken y-axis mode ----
        # NOTE: matplotlib's gridspec keyword is "height_ratios";
        # "heightRatios" is rejected as an unknown GridSpec argument.
        fig, (ax1, ax2) = subplots(2, 1, sharex=true, figsize=(2, 2),
            gridspec_kw=Dict("height_ratios" => heightRatios, "hspace" => 0.0),
            layout="constrained")

        # Upper panel shows yZoomPR[2] → 1, lower panel 0 → yZoomPR[1].
        styleAxis!(ax1; xZoom=xLimitRecall, yRange=(yZoomPR[2], 1.0), xStep=xStep,
                   yStep=yStep, yScale=yScale, setXticks=false)
        styleAxis!(ax2; xZoom=xLimitRecall, yRange=(0, yZoomPR[1]), xStep=xStep,
                   yStep=yStep, yScale=yScale)

        lastPlotData = drawCurves!([ax1, ax2])

        # Random baseline on both panels if available.
        if hasRandPR(lastPlotData)
            for ax in (ax1, ax2)
                ax.axhline(lastPlotData[:randPR], linestyle="-.", linewidth=2,
                           color=[0.6, 0.6, 0.6], label="Random")
            end
        end

        # Hide the facing spines between the panels.
        ax1.spines["bottom"].set_visible(false)
        ax2.spines["top"].set_visible(false)
        ax1.tick_params(axis="x", which="both", bottom=false, top=false,
                        labelbottom=false, labeltop=false)
        ax2.xaxis.tick_bottom()

        # Diagonal break marks at the gap.
        d = 0.006    # size of the marks, in axes coordinates
        ax1.plot([-d, +d], [-d, +d]; transform=ax1.transAxes, color="k", clip_on=false)
        ax1.plot([1 - d, 1 + d], [-d, +d]; transform=ax1.transAxes, color="k", clip_on=false)
        ax2.plot([-d, +d], [1 - d, 1 + d]; transform=ax2.transAxes, color="k", clip_on=false)
        ax2.plot([1 - d, 1 + d], [1 - d, 1 + d]; transform=ax2.transAxes, color="k", clip_on=false)

        # Shared axis labels.
        fig.supylabel("Precision", fontsize=axisTitleSize)
        ax2.set_xlabel("Recall", fontsize=axisTitleSize)

        savePath = prSavePath()
        PyPlot.savefig(savePath, dpi=600)    # version without legend

        if isInside
            ax2.legend(fontsize=legendSize, borderaxespad=0.2, frameon=true)
        else
            ax2.legend(fontsize=legendSize, loc="center", bbox_to_anchor=(1.25, 0.5),
                       borderaxespad=0.2, frameon=false)
            fig.set_size_inches(5, 4)    # wider only for outside legend
        end
        savePathLegend = replace(savePath, "PR_noLegend.pdf" => "PR.pdf")
        PyPlot.savefig(savePathLegend, dpi=600)
    else
        # Previously this case fell through silently.
        @warn "yZoomPR must have 0, 1, or 2 elements; got $(length(yZoomPR)) — nothing plotted"
    end

    PyPlot.close("all")
    # Return the saved path so callers can locate the figure (as documented).
    return savePathLegend
end
# -------------------------------------------------------------------------------------
# Part 2: Making a DotPlot or BarPlot of AUPR using PyPlot
# -------------------------------------------------------------------------------------

"""
    plotAUPR(gsParam, dirOut; kwargs...)

Plot AUPR values for one or more networks evaluated against one or more gold
standards, as a grouped bar plot or a dot plot, and save the figure as a PDF
inside `dirOut/AUPR/`.

# Arguments
- `gsParam::OrderedDict{String,<:AbstractDict}`: gold-standard name →
  (network name → path of a saved performance-metric JLD file).
- `dirOut::String`: output directory.

# Keywords
- `saveName=nothing` : base file name (timestamped default when empty).
- `metricType="full"` : `"full"` or `"partial"` AUPR.
- `figSize=(5,5)`, `axisTitleSize=9`, `tickLabelSize=7`, `legendFontSize=9`,
  `tickRotation=45` : styling.
- `plotType=""` : `"bar"`, `"dot"`, or empty for auto (dot when only one gold
  standard or one network, otherwise grouped bars).
- `saveLegend=true` : include the legend in the saved figure.

# Returns
Path of the saved PDF file.
"""
function plotAUPR(gsParam::OrderedDict{String,<:AbstractDict}, dirOut::String;
        saveName::Union{Nothing,String}=nothing,
        metricType::String="full",
        figSize::Tuple{Real,Real}=(5, 5),
        axisTitleSize::Int=9, tickLabelSize::Int=7,
        legendFontSize::Int=9, tickRotation::Int=45,
        plotType::String="",
        saveLegend::Bool=true)

    # All AUPR figures go into their own subdirectory.
    dirOut = joinpath(dirOut, "AUPR")
    mkpath(dirOut)

    # Read one AUPR value from a saved results file; `nothing` on failure.
    function loadAupr(filePath::String, metricType::String)
        try
            auprs = load(filePath)["results"][:auprs]
            return lowercase(metricType) == "partial" ? auprs[:partial][:value] : auprs[:full]
        catch e
            @warn "Could not load AUPR from $filePath" exception=(e, catch_backtrace())
            return nothing
        end
    end

    # Long-format table: one row per (gold standard, network). Missing files
    # become NaN so gaps stay visible in the plot instead of silently vanishing.
    gsNames = collect(keys(gsParam))
    netNames = unique(vcat([collect(keys(files)) for files in values(gsParam)]...))

    # Pad the palette via the shared helper so >19 networks cannot BoundsError.
    lineColors = padColors(netNames)

    dfRows = Vector{NamedTuple{(:xGroups, :Network, :AUPR),Tuple{String,String,Float64}}}()
    for gs in gsNames, net in netNames
        filePath = haskey(gsParam[gs], net) ? gsParam[gs][net] : nothing
        auprVal = filePath === nothing ? nothing : loadAupr(filePath, metricType)
        push!(dfRows, (xGroups=gs, Network=net, AUPR=something(auprVal, NaN)))
    end

    df = DataFrame(dfRows)
    numGS = length(unique(df.xGroups))
    numNet = length(unique(df.Network))

    # Bar vs dot: honor an explicit request, otherwise dots for 1×N / N×1 layouts.
    autoPlotType = lowercase(plotType) in ["bar", "dot"] ? lowercase(plotType) :
                   (numGS == 1 || numNet == 1 ? "dot" : "bar")

    xLabelVal = numGS == 1 ? "Network" : "Gold Standards"
    yLabelVal = lowercase(metricType) == "partial" ? "Partial AUPR" : "AUPR"

    # AUPR for gold standard i / network j (rows were pushed network-fastest above).
    auprAt(i, j) = df.AUPR[(i - 1) * numNet + j]

    # Build one figure; `includeLegend` switches between the legend and the
    # stripped no-legend version.
    function make_plot(includeLegend::Bool)
        fig, ax = plt.subplots(figsize=figSize, layout="constrained")
        colors = lineColors[1:numNet]

        if autoPlotType == "bar"
            if numGS == 1
                # One artist per network (scalar call — broadcasting a call with
                # a shared `label` would create duplicate legend entries).
                for (j, net) in enumerate(netNames)
                    ax.bar(j, auprAt(1, j), color=colors[j],
                           label=(includeLegend ? net : nothing))
                end
                ax.set_xticks(1:numNet)
                ax.set_xticklabels(netNames, rotation=tickRotation)
            elseif numNet == 1
                ax.bar(1:numGS, df.AUPR, color=colors[1],
                       label=(includeLegend ? netNames[1] : nothing))
                ax.set_xticks(1:numGS)
                ax.set_xticklabels(gsNames, rotation=tickRotation)
            else
                # Grouped bars: one group per gold standard, one bar per network.
                barWidth = 0.15
                groupWidth = numNet * barWidth
                groupSpacing = 0.4
                groupPositions = [i * (groupWidth + groupSpacing) for i in 0:(numGS - 1)]
                for j in 1:numNet
                    offset = (j - (numNet + 1) / 2) * barWidth
                    xPos = [gp + groupWidth / 2 + offset for gp in groupPositions]
                    ax.bar(xPos, [auprAt(i, j) for i in 1:numGS], barWidth,
                           color=colors[j], label=(includeLegend ? netNames[j] : nothing))
                end
                ax.set_xticks([gp + groupWidth / 2 for gp in groupPositions])
                ax.set_xticklabels(gsNames, rotation=tickRotation)
            end
        else # "dot"
            if numGS == 1
                for (j, net) in enumerate(netNames)
                    ax.scatter(j, auprAt(1, j), color=colors[j], s=40,
                               label=(includeLegend ? net : nothing))
                end
                ax.set_xticks(1:numNet)
                ax.set_xticklabels(netNames, rotation=tickRotation)
            elseif numNet == 1
                ax.scatter(1:numGS, df.AUPR, color=colors[1], s=40,
                           label=(includeLegend ? netNames[1] : nothing))
                ax.set_xticks(1:numGS)
                ax.set_xticklabels(gsNames, rotation=tickRotation)
            else
                # One scatter per network across all gold standards: a single
                # artist per network keeps the legend free of duplicates.
                for j in 1:numNet
                    ax.scatter(1:numGS, [auprAt(i, j) for i in 1:numGS],
                               color=colors[j], s=40,
                               label=(includeLegend ? netNames[j] : nothing))
                end
                ax.set_xticks(1:numGS)
                ax.set_xticklabels(gsNames, rotation=tickRotation)
            end
        end

        ax.set_xlabel(xLabelVal, fontsize=axisTitleSize)
        ax.set_ylabel(yLabelVal, fontsize=axisTitleSize)
        ax.tick_params(axis="both", which="both")
        ax.grid(true, which="major", linestyle="-", linewidth=0.5, color="lightgray")
        if includeLegend
            ax.legend(fontsize=legendFontSize, loc="center", bbox_to_anchor=(1.25, 0.5),
                      borderaxespad=0.2, frameon=false)
        else
            # No-legend version also drops the x-axis decoration (matches the
            # behavior of the prior implementation).
            ax.set_xticks([])
            ax.set_xticklabels([])
            ax.set_xlabel("")
        end

        return fig, ax
    end

    saveName = (saveName === nothing || isempty(saveName)) ?
               "AUPR_$(Dates.format(now(), "yyyymmdd_HHMMSS"))" : saveName
    savePath = joinpath(dirOut, saveName * "_$(autoPlotType)_" *
                                (saveLegend ? "withLegend" : "noLegend") * ".pdf")
    fig, _ = make_plot(saveLegend)
    fig.savefig(savePath, dpi=600)
    plt.close(fig)
    @info "Saved plot to: $savePath"
    return savePath
end
+ + +# # Arguments +# - `gsParam::OrderedDict{String,Dict{String,String}}`: Mapping of gold standards to network file paths. +# - `dirOut::String`: Directory to save the plot. +# - `saveName::String`: Base name for the output file. + +# # Keywords +# - `figSize::Tuple{Real, Real}=(5, 5)`: Size of the figure. + +# # Returns +# Nothing, but saves the plot to a file. +# """ + +# function loadAupr(filePath) +# try +# data = load(filePath) +# return data["results"][:auprs] +# catch e +# println("Error loading $filePath: $e") +# return nothing +# end +# end + +# # Check if dirOut is either nothing or an empty string +# dirOut = if dirOut === nothing || isempty(dirOut) +# pwd() +# else +# mkpath(dirOut) +# end + +# if isempty(figSize) || length(figSize) == 1 +# println("Warning: figSize must be a tuple of 2 element") +# figSize = (5,5) +# end + +# axisTitleSize = 16 +# tickLabelSize = 14 +# plotTitleSize = 16 +# legendSize = 10 + +# lineColors = [ #Initial Color palette +# "#377eb8", "#ff7f00", "#4daf4a", "#f781bf", "#a65628", +# "#984ea3", "#e41a1c", "#00ced1", "#000000", +# "#5A9D5A", "#D96D3B", "#FFAD12", "#66628D", "#91569A", +# "#B6742A", "#DD87B4", "#D26D7A", "#dede00"] + +# # Load aupr values +# gsNames = collect(keys(gsParam)) +# netNames = unique(vcat([collect(keys(files)) for files in values(gsParam)]...)) +# # auprValues = [loadAupr(gsParam[gs][netName]) !== nothing ? loadAupr(gsParam[gs][netName]) : 0.0 for gs in gsNames, netName in netNames] + +# auprValues = [] +# # Load AUPR values for each GS and file +# for gs in gsNames +# for netName in netNames +# # Load the AUPR value +# filePath = gsParam[gs][netName] +# aupr = loadAupr(filePath) + +# # Use 0.0 if the AUPR value is missing +# push!(auprValues, aupr === nothing ? 
0.0 : aupr) +# end +# end +# # Reshape auprValues into a matrix +# auprMatrix = transpose(reshape(auprValues, length(netNames), length(gsNames))) + +# boolNet = any(length(netNames) > 1 for netNames in values(gsParam)) +# numNet = length(netNames) +# numGS = length(gsNames) +# # = +# # ---- Load AUPR values (Works but not using) +# # auprData = Dict(gs => Dict(name => loadAupr(path) for (name, path) in files) for (gs, files) in gsParam) +# # netNames = unique(vcat([collect(keys(files)) for files in values(gsParam)]...)) +# # auprValues = [get(auprData[gs], netName, 0.0) for gs in gsNames, netName in netNames] +# = # +# if boolNet && numGS > 1 +# # Create a figure +# fig, ax = plt.subplots(figsize=figSize, layout="constrained") +# # Define parameters for grouping: +# barWidth = 0.15 # Width of each bar. +# groupSpacing = 0.4 # Extra space between groups. +# groupWidth = numNet * barWidth # Total width occupied by bars in one group. + +# # Compute positions for each group. +# # Compute a vector where each element is the left edge (or starting point) of each group. +# # Use 0-based indexing for groups. +# groupPositions = [i * (groupWidth + groupSpacing) for i in 0:(numGS-1)] + +# # Now plot each bar (each network) in every group. +# # Use a centered offset for each bar in the group. +# for idx in 1:numNet +# # Calculate the offset to center the bar within the group. +# # (idx - (numNet+1)/2)*barWidth shifts bars so they cluster around the center. +# offset = (idx - (numNet+1)/2)*barWidth +# # The x-positions for the bars are: +# # groupPositions + offset + groupWidth/2 +# # Adding groupWidth/2 centers the bars in each group. +# x_positions = [gp + groupWidth/2 + offset for gp in groupPositions] +# # Plot the bar for curr network across all groups. +# ax.bar(x_positions, auprMatrix[:, idx], barWidth, label=netNames[idx], +# color = lineColors[mod1(idx, length(lineColors))]) +# end + +# # Set the xticks to the center positions of each group. 
+# groupCenters = [gp + groupWidth/2 for gp in groupPositions] +# ax.set_xticks(groupCenters) +# ax.set_xticklabels(gsNames) +# ax.set_xlabel("Gold Standards", fontsize=axisTitleSize) +# ax.set_ylabel("AUPR", fontsize=axisTitleSize) +# ax.tick_params(axis="both", which="both", labelsize=tickLabelSize) +# ax.legend(fontsize=legendSize, loc="center left", bbox_to_anchor=(1, 0.5)) + +# elseif numGS == 1 || numNet == 1 +# # Dot Plot +# fig, ax = subplots(figsize= figSize, layout="constrained") +# if numGS == 1 +# # One GS, multiple files +# for idx in 1:numNet +# ax.scatter(idx, auprMatrix[idx], color=lineColors[idx], label=netNames[idx]) +# ax.text(idx, auprMatrix[idx], netNames[idx], fontsize=9, ha="right") +# end +# ax.set_xlabel("Network", fontsize=axisTitleSize) +# ax.set_xticks([]) +# ax.set_xticklabels([]) +# # ax.grid(true, which="major", linestyle="-", linewidth=0.5, color="lightgray") +# else +# for idx in 1:numGS +# ax.scatter(idx, auprMatrix[idx], color=lineColors[idx], label=gsNames[idx]) +# ax.text(idx, auprMatrix[idx], gsNames[idx], fontsize=9, ha="right") +# end +# ax.set_xlabel("Gold Standard", fontsize=axisTitleSize) +# ax.set_xticks([]) +# ax.set_xticklabels([]) +# # ax.grid(true, which="major", linestyle="-", linewidth=0.5, color="lightgray") +# end +# end + +# # Common settings for both plots. 
# ax.grid(true, which="major", linestyle="-", linewidth=0.5, color="lightgray")
# ax.tick_params(axis="both", which="both", labelsize=tickLabelSize)
# # plt.tight_layout()


# if !isempty(saveName)
#     plt.savefig(joinpath(dirOut, saveName * "_AUPR.pdf"), dpi=600)
# else
#     dateStr = Dates.format(now(), "yyyymmdd_HHMMSS")
#     plt.savefig(joinpath(dirOut, dateStr * "_AUPR.pdf"), dpi=600)
# end

# plt.close("all")
# end


# ---------------------------------------------------------------------------------------------------------------------------------------------------------------------
# ------ Using StatsPlots
# ---------------------------------------------------------------------------------------------------------------------------------------------------------------------
# function plotAUPR(gsParam::OrderedDict{String,<:AbstractDict}, dirOut::String,
#                   saveName::String, plotType::String = "bar";
#                   figSize::Tuple{Real,Real} = (8,5), axisTitleSize::Int = 13,
#                   tickLabelSize::Int = 10)

#     legendFontSize = 9

#     # Check if dirOut is either nothing or an empty string
#     dirOut = if dirOut === nothing || isempty(dirOut)
#         pwd()
#     else
#         mkpath(dirOut)
#     end

#     # Dummy load function – replace with your actual data-loading mechanism.
+# function loadAupr(filePath) +# try +# data = load(filePath) +# return data["results"][:auprs] +# catch e +# println("Error loading $filePath: $e") +# return nothing +# end +# end + +# # Initial color palette +# lineColors = ["#377eb8", "#ff7f00", "#4daf4a", "#f781bf", "#a65628", +# "#984ea3", "#e41a1c", "#00ced1", "#000000", +# "#5A9D5A", "#D96D3B", "#FFAD12", "#66628D", "#91569A", +# "#B6742A", "#DD87B4", "#D26D7A", "#dede00"] + +# # Build a DataFrame with columns: xGroups, Network, AUPR +# gsNames = collect(keys(gsParam)) +# netNames = unique(vcat([collect(keys(files)) for files in values(gsParam)]...)) +# if length(netNames) > length(lineColors) +# lineColors = padColors(netNames) # Ensure padColors is defined! +# end + +# dfRows = Vector{NamedTuple{(:xGroups, :Network, :AUPR), Tuple{String,String,Float64}}}() +# for gs in gsNames +# for net in netNames +# filePath = haskey(gsParam[gs], net) ? gsParam[gs][net] : nothing +# temp = filePath === nothing ? nothing : loadAupr(filePath) +# val = (temp === nothing || temp === missing) ? 0.0 : temp +# push!(dfRows, (xGroups = gs, Network = net, AUPR = val)) +# end +# end +# df = DataFrame(dfRows) + +# # Number of unique groups. +# numGS = length(unique(df.xGroups)) +# numNet = length(unique(df.Network)) + +# # Set common labels and legend; use camelCase. +# xLabelVal = (numGS == 1 ? "Network" : "Gold Standards") +# yLabelVal = "AUPR" +# legendPos = :outerright + +# # Common arguments (for scatter and bar, without color arguments). +# commonArgs = (xlabel = xLabelVal, ylabel = yLabelVal, legend = legendPos, +# guidefontsize = axisTitleSize, tickfontsize = tickLabelSize, legendfontsize = legendFontSize, +# framestyle = :box) + +# colors = lineColors[1:numNet] + +# p = nothing +# if lowercase(plotType) == "scatter" +# if numGS == 1 +# p = @df df sp.scatter(:Network, :AUPR; markersize = 6, +# color = colors, xrotation = 45, commonArgs...) 
+# elseif numNet == 1 +# p = @df df sp.scatter(:xGroups, :AUPR; markersize = 6, +# color = lineColors[1], xrotation = 45, commonArgs...) +# else +# p = @df df sp.scatter(:xGroups, :AUPR; group = :Network, markersize = 6, +# palette = colors, commonArgs...) +# end + +# elseif lowercase(plotType) == "bar" +# if numGS == 1 +# p = @df df sp.bar(:Network, :AUPR; +# color = colors, xrotation = 45, commonArgs...) +# elseif numNet == 1 +# p = @df df sp.bar(:xGroups, :AUPR; +# color = lineColors[1], xrotation = 45, commonArgs...) +# else +# p = @df df sp.groupedbar(:xGroups, :AUPR; group = :Network, +# palette = colors, commonArgs...) +# end +# else +# error("Invalid plot type. Please choose 'scatter' or 'bar'") +# end + +# # Adjust figure size (e.g. figSize = (6,3.5) converts to (600,350) pixels) +# p = sp.plot(p, size = (Int(figSize[1]*100), Int(figSize[2]*100)), titlefontsize = axisTitleSize) + + +# if !isempty(saveName) +# savePath = joinpath(dirOut, saveName * "_AUPR_$(plotType).pdf") +# sp.savefig(p, savePath) +# else +# dateStr = Dates.format(now(), "yyyymmdd_HHMMSS") +# savePath = joinpath(dirOut, dateStr * "_AUPR_$(plotType).pdf") +# sp.savefig(p, savePath) +# end +# end + + + diff --git a/src/metrics/plotting/PlotSingle.jl b/src/metrics/plotting/PlotSingle.jl new file mode 100755 index 0000000..a202236 --- /dev/null +++ b/src/metrics/plotting/PlotSingle.jl @@ -0,0 +1,94 @@ +""" + plotPRCurve(rec, prec, randPR, saveDir; kwargs...) + +Plot and save a Precision-Recall (PR) curve using PyPlot. + +# Arguments +- `rec::AbstractVector` : Recall values. +- `prec::AbstractVector` : Precision values. +- `randPR::Float64` : Random baseline precision (horizontal reference line). +- `saveDir::String` : Directory to save the plot. + +# Keyword Arguments +- `xLimitRecall::Float64=1.0` : Maximum recall shown on the x-axis of diagnostic PR plots. +- `axisTitleSize::Int=16` : Font size for axis titles. +- `tickLabelSize::Int=14` : Font size for tick labels. 
- `plotTitleSize::Int=18` : Font size for plot title.
- `baseName=nothing` : Base filename. If `nothing`, a timestamp is used.
"""
function plotPRCurve(
    rec::AbstractVector, prec::AbstractVector, randPR::Float64, saveDir::String;
    xLimitRecall::Float64 = 1.0, axisTitleSize::Int = 16,
    tickLabelSize::Int = 14, plotTitleSize::Int = 18, baseName = nothing)

    # Fall back to a timestamp when no base filename was supplied.
    tag = baseName === nothing ? Dates.format(now(), "yyyymmdd_HHMMSS") : baseName

    # Guard clause: nothing meaningful to draw.
    if isempty(rec) || isempty(prec) || !any(prec .> 0)
        @warn "Skipping PR plot for $tag: empty or all-zero precision/recall vectors"
        return nothing
    end

    PyPlot.figure()
    # Random-classifier baseline drawn first so the curve sits on top of it.
    PyPlot.axhline(y=randPR, linestyle="-.", color="k")
    PyPlot.plot(rec, prec, color="b")
    PyPlot.xlabel("Recall", fontsize=axisTitleSize)
    PyPlot.ylabel("Precision", fontsize=axisTitleSize)
    PyPlot.title("PR Curve: $tag", fontsize=plotTitleSize)
    PyPlot.xlim(0, xLimitRecall)
    PyPlot.ylim(0, 1)
    PyPlot.grid(true, which="major", linestyle="--", linewidth=0.75, color="gray")
    PyPlot.minorticks_on()
    PyPlot.grid(true, which="minor", linestyle=":", linewidth=0.5, color="lightgray")
    # Major / minor ticks share the same styling call; minor labels 2pt smaller.
    for (band, labelsize) in (("major", tickLabelSize), ("minor", tickLabelSize - 2))
        PyPlot.tick_params(axis="both", which=band, labelsize=labelsize)
    end
    PyPlot.savefig(joinpath(saveDir, tag * "_PR.png"), dpi=600)
    PyPlot.close()
    return nothing
end


"""
    plotROCCurve(fpr, tpr, saveDir; kwargs...)

Plot and save a ROC curve using PyPlot.

# Arguments
- `fpr::AbstractVector` : False positive rate values (x-axis).
- `tpr::AbstractVector` : True positive rate / recall values (y-axis).
- `saveDir::String` : Directory to save the plot.

# Keyword Arguments
- `axisTitleSize::Int=16` : Font size for axis titles.
- `tickLabelSize::Int=14` : Font size for tick labels.
- `plotTitleSize::Int=18` : Font size for plot title.
- `baseName=nothing` : Base filename. If `nothing`, a timestamp is used.
+""" +function plotROCCurve( + fpr::AbstractVector, tpr::AbstractVector, saveDir::String; + axisTitleSize::Int = 16, tickLabelSize::Int = 14, + plotTitleSize::Int = 18, baseName = nothing) + + if baseName === nothing + baseName = Dates.format(now(), "yyyymmdd_HHMMSS") + end + + if !isempty(fpr) && !isempty(tpr) && (maximum(tpr) > 0 || maximum(fpr) > 0) + PyPlot.figure() + PyPlot.plot([0, 1], [0, 1], linestyle="--", color="k") + PyPlot.plot(fpr, tpr, color="b") + PyPlot.xlabel("FPR", fontsize=axisTitleSize) + PyPlot.ylabel("TPR", fontsize=axisTitleSize) + PyPlot.title("ROC Curve: $baseName", fontsize=plotTitleSize) + PyPlot.xlim(0, 1) + PyPlot.ylim(0, 1) + PyPlot.grid(true, which="major", linestyle="--", linewidth=0.75, color="gray") + PyPlot.minorticks_on() + PyPlot.grid(true, which="minor", linestyle=":", linewidth=0.5, color="lightgray") + PyPlot.tick_params(axis="both", which="major", labelsize=tickLabelSize) + PyPlot.tick_params(axis="both", which="minor", labelsize=tickLabelSize - 2) + PyPlot.savefig(joinpath(saveDir, baseName * "_ROC.png"), dpi=600) + PyPlot.close() + else + @warn "Skipping ROC plot for $baseName: empty or all-zero TPR/FPR vectors" + end +end \ No newline at end of file diff --git a/src/prior/MergeDegenerateTFs.jl b/src/prior/MergeDegenerateTFs.jl new file mode 100755 index 0000000..671cb7a --- /dev/null +++ b/src/prior/MergeDegenerateTFs.jl @@ -0,0 +1,366 @@ + using CSV + using DataFrames + using Printf + using FileIO + using Base.Iterators: partition + using Base.Threads + + """ + This script merges transcription factors (TFs) with identical target sets and interaction weights + into meta-TFs, producing a merged prior network and summary tables. + + # Inputs + - `networkFile::String` — Path to the input network file. + - `outFileBase::Union{String, Nothing}` — Base path for output files. If `nothing`, defaults to the input file's base name. 
- `fileFormat::Int` — Format of the input network:
    - `1` = Long format (TF, Target, Weight)
    - `2` = Wide format (Targets × TFs matrix)
- `connector::String` (default = "_") — String used to join TF names into merged meta-TF names.

# Output Files
1. `*_merged_sp.tsv` — Long-format network with merged TF names.
2. `*_merged.tsv` — Wide-format matrix of the merged network (targets × regulators).
3. `*_overlaps.tsv` — TF-by-TF overlap matrix showing shared targets.
4. `*_targetTotals.tsv` — Number of targets per (possibly merged) TF.
5. `*_mergedTfs.tsv` — Mappings from each merged meta-TF to its original TF members.

# Notes
- Run Julia with multiple threads (e.g., `julia --threads=6`) to enable parallel processing.
- Ensure that TF names do not already include the `connector` string to avoid unintended merges.
- Zeros are removed in the long-format representation before merging.
- The returned merged network is written in wide format and can be used for downstream analysis.

# Returns
A `NamedTuple` with:
- `merged::DataFrame` — Wide-format merged matrix (targets x meta-TFs).
- `mergedTfs::Matrix{String}` — A two-column matrix where each row contains a meta-TF name and a comma-separated list of its constituent TFs.
"""

# Struct defined in src/Types.jl


"""
    countOverlap(s1::AbstractSet, s2::AbstractSet)

Count the number of shared elements between two sets.

Returns the size of the intersection without materializing a new set.
"""
function countOverlap(s1::AbstractSet, s2::AbstractSet)
    # NOTE: the docstring above was previously a string literal *inside* the body,
    # where it is evaluated and discarded on every call and never attached to the
    # function; it now precedes the definition as a proper docstring.
    # Iterate over the smaller set so at most min(|s1|, |s2|) lookups are done.
    if length(s1) > length(s2)
        s1, s2 = s2, s1
    end
    c = 0
    for t in s1
        if t in s2
            c += 1
        end
    end
    return c
end

"""
    readNetwork(networkFile::String; fileFormat::Int)

Parse a regulatory network file in either long or wide format.

# Arguments
- `networkFile::String`: Path to the network file.
- `fileFormat::Int`:
    - `1` = long format (TF, Target, Weight)
    - `2` = wide format (Targets × TFs matrix)

# Returns
- A dictionary mapping TF names to sets of (target, weight) tuples.

# Notes
- Entries with weight zero are ignored in both formats.
- Weights are stored as strings (converted to Float64 later).
"""
function readNetwork(networkFile::String; fileFormat::Int)
    # TF name => set of (target, weight-as-string) pairs.
    tfTargDic = Dict{String, Set{Tuple{String, String}}}()

    if fileFormat == 1
        # Long format: each line has the form TF <tab> Target <tab> Weight.
        open(networkFile, "r") do file
            readline(file) # skip header
            for line in eachline(file)
                line = strip(line)
                isempty(line) && continue
                parts = split(line, '\t')
                length(parts) < 3 && continue
                tfName, target, weight = parts[1], parts[2], parts[3]
                # Skip zero-weight edges. Parsing the weight (instead of comparing to
                # the literal string "0") also drops "0.0"/"0.00", making this branch
                # consistent with the wide-format branch's numeric `!= 0` filter.
                w = tryparse(Float64, weight)
                if w === nothing || w != 0
                    push!(get!(() -> Set{Tuple{String, String}}(), tfTargDic, String(tfName)),
                          (target, weight))
                end
            end
            println("Dict Created!!!")
        end

    elseif fileFormat == 2
        # Wide format: first column = targets, remaining columns = TFs.
        df = CSV.File(networkFile, header=1; delim='\t') |> DataFrame
        melted = stack(df, Not(1))
        rename!(melted, names(melted)[1] => :Target, names(melted)[2] => :TF, names(melted)[3] => :Weight)
        select!(melted, :TF, :Target, :Weight)
        filter!(row -> row.Weight != 0, melted)
        for row in eachrow(melted)
            # Weights are stored as strings here for consistency with the long branch;
            # the set is now constructed with its concrete element type instead of
            # the untyped `Set()` the original used.
            push!(get!(() -> Set{Tuple{String, String}}(), tfTargDic, String(row.TF)),
                  (String(row.Target), string(row.Weight)))
        end
    end
    return tfTargDic
end


function groupRedundantTFs(tfTargDic::Dict{String, Set{Tuple{String, String}}})
    """
    Compute overlapping TFs with identical target–weight sets and group them into meta-TFs.

    # Arguments
    - `tfTargDic::Dict{String, Set{Tuple{String, String}}}`: Mapping from TFs to their target–weight pairs.

    # Returns
    - `tfMergers::Dict{String, Vector{String}}`: TFs grouped into merge sets.
    - `overlaps::Dict{Tuple{String, String}, Int}`: Pairwise TF overlap counts.
    - `tfTargNums::Dict{String, Int}`: Number of targets per TF.
    - `tfNames::Vector{String}`: All TF names (sorted).

    # Notes
    - Uses threading if multiple threads are available.
    - TFs are merged if they have exactly the same set of target–weight pairs.
    """
    # Check if more than 1 thread is available.
    nthreads = Threads.nthreads()
    use_threads = nthreads > 1

    # Step A: Determine pairwise TF overlaps.
    tfNames = sort(collect(keys(tfTargDic)))
    tfMergers = Dict{String, Vector{String}}()
    overlaps = Dict{Tuple{String, String}, Int}()
    tfTargNums = Dict{String, Int}()

    # Fill tfTargNums and the diagonal of the overlap matrix
    # (a TF trivially shares all of its targets with itself).
    for (tf, targets) in tfTargDic
        n = length(targets)
        tfTargNums[tf] = n
        overlaps[(tf,tf)] = n #seed the diagonal of the overlaps matrix
    end

    if use_threads
        # One private dictionary per thread; results are merged into the global
        # dictionaries after the loop, so the threaded section needs no locking.
        localOverlapsArray = [Dict{Tuple{String, String}, Int}() for i in 1:nthreads]
        localTFMergersArray = [Dict{String, Set{String}}() for i in 1:nthreads]

        println("Running in threaded mode ...")
        # NOTE(review): indexing per-thread buffers by threadid() assumes an
        # iteration never migrates between threads mid-body. That holds under the
        # static schedule, but is fragile under dynamic scheduling (the @threads
        # default from Julia 1.8) — TODO confirm target Julia version, or chunk
        # the index range per task instead of per threadid.
        Threads.@threads for i in 1:length(tfNames)
            # Use the thread's id (an integer between 1 and nthreads) to access its local storage.
            tid = threadid()
            localOverlaps = localOverlapsArray[tid]
            localMergers = localTFMergersArray[tid]

            tf1 = tfNames[i]
            tf1targets = tfTargDic[tf1]
            # Only the upper triangle (j > i) is computed; both orderings are stored.
            for j in (i+1):length(tfNames)
                tf2 = tfNames[j]
                tf2targets = tfTargDic[tf2]
                overlapSize = countOverlap(tf1targets, tf2targets)

                # Save both orderings if needed.
                localOverlaps[(tf1, tf2)] = overlapSize
                localOverlaps[(tf2, tf1)] = overlapSize

                # If the overlap equals each TF's target count then they fully overlap.
                if tfTargNums[tf1] == overlapSize && tfTargNums[tf2] == overlapSize
                    # Update local mergers for tf1.
                    localMergers[tf1] = get!(localMergers, tf1, Set([tf1]))
                    push!(localMergers[tf1], tf2)

                    # Update local mergers for tf2.
                    localMergers[tf2] = get!(localMergers, tf2, Set([tf2]))
                    push!(localMergers[tf2], tf1)
                end
            end
        end

        # Merge the thread-local results into global dictionaries. Symmetric keys
        # seen by different threads carry the same value, so plain assignment is safe.
        for lo in localOverlapsArray
            for (k,v) in lo
                overlaps[k] = v
            end
        end

        for localx in localTFMergersArray
            for (tf, mergerSet) in localx
                if haskey(tfMergers, tf)
                    # Combine the sets, then convert back to vector making sure they're unique.
                    unionSet = union(Set(tfMergers[tf]), mergerSet)
                    tfMergers[tf] = collect(unionSet)
                else
                    tfMergers[tf] = collect(mergerSet)
                end
            end
        end

    else
        println("Running in sequential mode ...")
        for i in 1:length(tfNames)

            tf1 = tfNames[i]
            tf1targets = tfTargDic[tf1]

            for j in i+1:length(tfNames)
                tf2 = tfNames[j]
                tf2targets = tfTargDic[tf2]

                overlapSize = countOverlap(tf1targets, tf2targets)
                overlaps[(tf1, tf2)] = overlapSize
                overlaps[(tf2, tf1)] = overlapSize

                if tfTargNums[tf1] == tfTargNums[tf2] == overlapSize
                    # Use get! to set default values and append.
                    push!(get!(tfMergers, tf1, [tf1]), tf2)
                    push!(get!(tfMergers, tf2, [tf2]), tf1)
                end
            end
        end
    end
    println("TF Overlaps Determination Complete!!")

    return tfMergers, overlaps, tfTargNums, tfNames
end


# Function to merge degenerate prior TFs.
function mergeDegenerateTFs(
    mergedTFsData::mergedTFsResult,
    networkFile::String;
    outFileBase::Union{String,Nothing}=nothing,
    fileFormat::Int = 2,
    connector::String = "_"
)
    """
    mergeDegenerateTFs(
        mergedTFsData::mergedTFsResult,
        networkFile::String;
        outFileBase::Union{String,Nothing}=nothing,
        fileFormat::Int=2,
        connector::String="_"
    )

    Merge TFs whose target/weight sets are identical into meta-TFs, write the
    summary output files, and store the results on `mergedTFsData` in place.

    - networkFile – path to your priorFile
    - outFileBase – optional "base path+stem" for writing the output files;
                    if `nothing` we auto-derive it from `networkFile`.
    - fileFormat  – 1=long, 2=wide
    - connector   – string between merged TF names, default "_"

    Mutates `mergedTFsData`:
    - `mergedPrior` = wide-format DataFrame (targets × regulators)
    - `mergedTFMap` = two-column matrix mapping each meta-TF name to a
      comma-separated list of its member TFs
    """

    # 1. ----- Read + compute mergers, overlaps, counts, names...
    tfTargDic = readNetwork(networkFile; fileFormat)
    tfMergers, overlaps, tfTargNums, tfNames = groupRedundantTFs(tfTargDic)

    # 2. ----- Write output files

    # If no outFileBase was supplied, derive it from networkFile:
    stem = splitext(basename(networkFile))[1] # filename without extension
    if outFileBase === nothing
        dir = dirname(networkFile)
        outFileBase = joinpath(dir, stem)
    else
        # Caller supplied a directory/base — append the input file's stem to it.
        outFileBase = joinpath(outFileBase,stem)
    end
    netOutFile = outFileBase * "_merged_sp.tsv"
    netMatOutFile = outFileBase * "_merged.tsv"
    overlapsOutFile = outFileBase * "_overlaps.tsv"
    totTargOutFile = outFileBase * "_targetTotals.tsv"
    mergedTfsOutFile = outFileBase * "_mergedTFs.tsv"

    # Open all output streams up front; they are closed in the `finally` below.
    netIO = open(netOutFile, "w")
    overlapsIO = open(overlapsOutFile, "w")
    totTargIO = open(totTargOutFile, "w")
    mergedTfsIO = open(mergedTfsOutFile, "w")

    # In-memory collector
    netDF = DataFrame(Regulator=String[], Target=String[], Weight=String[])
    tabMergedTFs = Vector{Vector{String}}() # will hold lines "metaTF\tmember1, member2,…"
    allMergedTfs = collect(keys(tfMergers))
    usedMergedTfs = Set{String}() # keeps track of used TFs, so we don't output mergers twice
    printedTfs = String[]

    try
        # Write header to network output file.
        println(netIO, "Regulator\tTarget\tWeight")
        for tf in tfNames
            tfPrint = nothing
            doPrint = false
            if tf in allMergedTfs && !(tf in usedMergedTfs)
                # First member of a merge group encountered: emit the whole group once
                # under a combined name (first two members, plus "..." if there are more).
                mergedTfs = sort(collect(tfMergers[tf]))
                union!(usedMergedTfs, mergedTfs)
                tfPrint = join(mergedTfs[1:2], connector) * (length(mergedTfs) > 2 ? "..." : "")
                # 1) write merged-TF mapping to disk
                line = tfPrint * "\t" * join(mergedTfs, ", ")
                println(mergedTfsIO, line)
                # 2) store in memory
                push!(tabMergedTFs, [tfPrint, join(mergedTfs, ", ")])
                doPrint = true

            elseif !(tf in allMergedTfs)
                # Unmerged TF: emitted under its own name.
                tfPrint = tf
                doPrint = true
            end

            if doPrint
                # Write target totals.
                println(totTargIO, "$(tfPrint)\t$(tfTargNums[tf])")
                # Write each TF-target-weight record to a line.
+ for (targ, wgt) in tfTargDic[tf] + outline = "$(tfPrint)\t$(targ)\t$(wgt)" + println(netIO, outline) + # also push into netDF + push!(netDF, (Regulator=tfPrint, Target=targ, Weight=wgt)) + end + push!(printedTfs, tfPrint) + end + end + + println(overlapsIO, "\t" * join(printedTfs, "\t")) + for tf in printedTfs + row = [ string(overlaps[(first(split(tf,connector)), first(split(pt,connector)))]) + for pt in printedTfs ] + println(overlapsIO, tf * '\t' * join(row, '\t')) + end + println("Overlap Output File Successfully Written!!!") + finally + close(netIO); close(overlapsIO); close(totTargIO); close(mergedTfsIO) + end + + netDF.Weight = parse.(Float64, netDF.Weight) + mergedPrior = convertToWide(netDF; indices=(2, 1, 3)) + mergedPrior .= coalesce.(mergedPrior, 0.0) + # write to file, while makig sure the first column is unnamed. + writeTSVWithEmptyFirstHeader(mergedPrior, netMatOutFile; delim ='\t') + + println("Output files:\n$mergedTfsIO\n$totTargOutFile\n$netOutFile\n$overlapsOutFile") + mergedTFsData.mergedPrior = mergedPrior + mergedTFsData.mergedTFMap = reduce(vcat, permutedims.(tabMergedTFs)) # Convert tabMergedTFs to a two columns matrix and then saves + + end diff --git a/test/runtests.jl b/test/runtests.jl new file mode 100755 index 0000000..55314bf --- /dev/null +++ b/test/runtests.jl @@ -0,0 +1,497 @@ +# test/runtests.jl +using Test +using InferelatorJL +using Random +using Statistics +using LinearAlgebra + +# ============================================================================= +# Helpers — write synthetic input files to a temp directory +# ============================================================================= + +""" +Write a gene expression TSV: + - Row 1: empty cell + sample names + - Rows 2+: gene name + Float64 values +""" +function write_expr_tsv(path, geneNames, sampleNames, mat) + open(path, "w") do io + println(io, "\t" * join(sampleNames, "\t")) + for (i, g) in enumerate(geneNames) + println(io, g * "\t" * join(string.(mat[i, 
:]), "\t")) + end + end +end + +""" +Write a gene list file (one name per line). +""" +write_gene_list(path, names) = write(path, join(names, "\n") * "\n") + +""" +Write a sparse prior TSV: + - Row 1: empty cell + TF names + - Rows 2+: gene name + Float64 values +""" +function write_prior_tsv(path, tfNames, geneNames, mat) + open(path, "w") do io + println(io, "\t" * join(tfNames, "\t")) + for (i, g) in enumerate(geneNames) + println(io, g * "\t" * join(string.(mat[i, :]), "\t")) + end + end +end + + +# ============================================================================= +@testset "InferelatorJL" begin +# ============================================================================= + + +# ----------------------------------------------------------------------------- +@testset "Structs instantiate" begin +# ----------------------------------------------------------------------------- + @test GeneExpressionData() isa GeneExpressionData + @test PriorTFAData() isa PriorTFAData + @test mergedTFsResult() isa mergedTFsResult + @test GrnData() isa GrnData + @test BuildGrn() isa BuildGrn +end + + +# ----------------------------------------------------------------------------- +@testset "Data loading — expression matrix" begin +# ----------------------------------------------------------------------------- + tmpdir = mktempdir() + + # Genes intentionally out of alphabetical order to test sort + genes = ["Zap70", "Akt1", "Myc", "Foxp3"] + samples = ["S1", "S2", "S3", "S4", "S5"] + mat = Float64[ + 1.0 2.0 3.0 4.0 5.0; # Zap70 + 6.0 7.0 8.0 9.0 10.0; # Akt1 + 0.1 0.2 0.1 0.2 0.1; # Myc + 3.0 3.5 4.0 4.5 5.0; # Foxp3 + ] + exprFile = joinpath(tmpdir, "expr.txt") + write_expr_tsv(exprFile, genes, samples, mat) + + data = GeneExpressionData() + InferelatorJL.loadExpressionData!(data, exprFile) + + # Gene names should be sorted alphabetically + @test data.geneNames == sort(genes) + + # Matrix rows must match the sorted order + sorted_order = sortperm(genes) # [2,4,3,1] → 
Akt1,Foxp3,Myc,Zap70 + @test data.geneExpressionMat == mat[sorted_order, :] + + # Sample labels are parsed from the header + @test data.cellLabels == samples + + # Dimensions: genes × samples + @test size(data.geneExpressionMat) == (length(genes), length(samples)) +end + + +# ----------------------------------------------------------------------------- +@testset "Data loading — non-numeric first column values are gene names" begin +# ----------------------------------------------------------------------------- + # Verify that a gene name like "1500009L16Rik" (starts with a digit) is + # treated as a gene name, not discarded or confused with a header. + tmpdir = mktempdir() + genes = ["1500009L16Rik", "Bcl2", "Cd44"] + samples = ["A", "B", "C"] + mat = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0] + + exprFile = joinpath(tmpdir, "expr_numnames.txt") + write_expr_tsv(exprFile, genes, samples, mat) + + data = GeneExpressionData() + InferelatorJL.loadExpressionData!(data, exprFile) + + @test "1500009L16Rik" in data.geneNames + @test length(data.geneNames) == 3 +end + + +# ----------------------------------------------------------------------------- +@testset "Data loading — invalid file path throws" begin +# ----------------------------------------------------------------------------- + data = GeneExpressionData() + @test_throws ErrorException InferelatorJL.loadExpressionData!(data, "/nonexistent/path/expr.txt") +end + + +# ----------------------------------------------------------------------------- +@testset "Data loading — target gene filtering" begin +# ----------------------------------------------------------------------------- + tmpdir = mktempdir() + genes = ["Akt1", "Bcl2", "Myc", "Foxp3", "Stat3"] + samples = ["S1", "S2", "S3", "S4", "S5", "S6"] + + # Myc has zero variance — should be removed with default epsilon + mat = Float64[ + 1.0 2.0 3.0 4.0 5.0 6.0; # Akt1 — good variance + 2.0 2.5 3.0 2.5 2.0 3.5; # Bcl2 — good variance + 1.0 1.0 1.0 1.0 1.0 1.0; # Myc — 
ZERO variance + 0.5 1.5 2.5 3.5 4.5 5.5; # Foxp3 — good variance + 3.0 1.0 4.0 1.0 5.0 9.0; # Stat3 — good variance + ] + exprFile = joinpath(tmpdir, "expr.txt") + write_expr_tsv(exprFile, genes, samples, mat) + + # Request 4 genes (exclude Stat3); Myc will be removed by variance filter + targFile = joinpath(tmpdir, "targs.txt") + write_gene_list(targFile, ["Akt1", "Bcl2", "Myc", "Foxp3"]) + + data = GeneExpressionData() + InferelatorJL.loadExpressionData!(data, exprFile) + InferelatorJL.loadAndFilterTargetGenes!(data, targFile; epsilon = 0.01) + + # Myc must have been removed + @test !("Myc" in data.targGenes) + + # The other three requested genes should be present + @test "Akt1" in data.targGenes + @test "Bcl2" in data.targGenes + @test "Foxp3" in data.targGenes + + # targGeneMat dimensions must match retained genes × samples + @test size(data.targGeneMat) == (3, length(samples)) +end + + +# ----------------------------------------------------------------------------- +@testset "Data loading — target file not found throws" begin +# ----------------------------------------------------------------------------- + tmpdir = mktempdir() + genes = ["Akt1", "Bcl2"] + samples = ["S1", "S2", "S3"] + mat = [1.0 2.0 3.0; 4.0 5.0 6.0] + exprFile = joinpath(tmpdir, "expr.txt") + write_expr_tsv(exprFile, genes, samples, mat) + + data = GeneExpressionData() + InferelatorJL.loadExpressionData!(data, exprFile) + @test_throws ErrorException InferelatorJL.loadAndFilterTargetGenes!( + data, "/no/such/file.txt") +end + + +# ----------------------------------------------------------------------------- +@testset "Data loading — potential regulators" begin +# ----------------------------------------------------------------------------- + tmpdir = mktempdir() + genes = ["Akt1", "Bcl2", "Foxp3", "Myc"] + samples = ["S1", "S2", "S3"] + mat = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0; 10.0 11.0 12.0] + exprFile = joinpath(tmpdir, "expr.txt") + write_expr_tsv(exprFile, genes, samples, mat) + + # 
Regulator list: 2 present in expression, 1 not present + regFile = joinpath(tmpdir, "regs.txt") + write_gene_list(regFile, ["Foxp3", "Akt1", "NotInExpr"]) + + data = GeneExpressionData() + InferelatorJL.loadExpressionData!(data, exprFile) + InferelatorJL.loadPotentialRegulators!(data, regFile) + + # Only the 2 regulators present in expression data should appear + @test length(data.potRegs) == 2 + @test "Foxp3" in data.potRegs + @test "Akt1" in data.potRegs + @test !("NotInExpr" in data.potRegs) + + # mRNA matrix: rows = regulators that have expression, cols = samples + @test size(data.potRegMatmRNA) == (2, length(samples)) +end + + +# ============================================================================= +# Helper — build a fully loaded GeneExpressionData from scratch in memory +# ============================================================================= +function make_test_data(tmpdir) + genes = ["Akt1", "Bcl2", "Foxp3", "Myc", "Stat3"] + tfs = ["Foxp3", "Stat3"] # regulators + samples = ["S$i" for i in 1:10] + + Random.seed!(42) + mat = rand(Float64, length(genes), length(samples)) .* 5 .+ 0.5 + + exprFile = joinpath(tmpdir, "expr.txt") + targFile = joinpath(tmpdir, "targs.txt") + regFile = joinpath(tmpdir, "regs.txt") + + write_expr_tsv(exprFile, genes, samples, mat) + write_gene_list(targFile, genes) # all genes are targets + write_gene_list(regFile, tfs) + + data = GeneExpressionData() + InferelatorJL.loadExpressionData!(data, exprFile) + InferelatorJL.loadAndFilterTargetGenes!(data, targFile; epsilon = 0.01) + InferelatorJL.loadPotentialRegulators!(data, regFile) + InferelatorJL.processTFAGenes!(data, "") # use all genes for TFA + + return data, genes, tfs, samples +end + + +# ----------------------------------------------------------------------------- +@testset "TFA — output dimensions" begin +# ----------------------------------------------------------------------------- + tmpdir = mktempdir() + data, genes, tfs, samples = make_test_data(tmpdir) 
+ + nTFs = length(tfs) + nGenes = length(genes) + nSamples = length(samples) + + # Prior: every gene has at least one edge per TF + # Shape: genes × tfs (written as rows=genes, cols=tfs) + prior_mat = Float64[ + 1 0; # Akt1 → Foxp3 + 1 1; # Bcl2 → Foxp3, Stat3 + 0 1; # Foxp3 → Stat3 + 1 1; # Myc → both + 1 1; # Stat3 → both + ] + priorFile = joinpath(tmpdir, "prior.tsv") + write_prior_tsv(priorFile, tfs, sort(genes), prior_mat) + + tfaData = PriorTFAData() + InferelatorJL.processPriorFile!(tfaData, data, priorFile; + mergedTFsData = nothing, + minTargets = 1) + InferelatorJL.calculateTFA!(tfaData, data; edgeSS = 0, zTarget = false) + + # medTfas must be nTFs × nSamples (result of priorMatrix \ targExpression) + @test size(tfaData.medTfas, 1) == length(tfaData.pRegs) + @test size(tfaData.medTfas, 2) == nSamples +end + + +# ----------------------------------------------------------------------------- +@testset "TFA — least-squares correctness" begin +# ----------------------------------------------------------------------------- + # With a square, well-conditioned prior, TFA = prior \ expression exactly. 
# NOTE(review): this chunk opens inside a @testset whose header sits above it;
# the statements below are that testset's body, closed by the first `end`.
    # Build a tiny 3-gene × 4-sample expression set with a dense prior so the
    # TFA solve has a direct least-squares answer we can recompute by hand.
    workdir = mktempdir()

    geneids   = ["Akt1", "Bcl2", "Foxp3"]
    tfids     = ["Foxp3", "Akt1"]
    sampleids = ["S1", "S2", "S3", "S4"]

    Random.seed!(7)
    exprvals = rand(Float64, length(geneids), length(sampleids)) .+ 1.0

    exprpath = joinpath(workdir, "expr.txt")
    targpath = joinpath(workdir, "targs.txt")
    regpath  = joinpath(workdir, "regs.txt")
    write_expr_tsv(exprpath, geneids, sampleids, exprvals)
    write_gene_list(targpath, geneids)
    write_gene_list(regpath, tfids)

    # Full prior: all 3 genes × 2 TFs, every cell nonzero (both TFs have >1 target)
    priorvals = Float64[
        1.0 0.5;
        0.5 1.0;
        0.8 0.3;
    ]
    priorpath = joinpath(workdir, "prior.tsv")
    write_prior_tsv(priorpath, tfids, geneids, priorvals)

    gdata = GeneExpressionData()
    InferelatorJL.loadExpressionData!(gdata, exprpath)
    InferelatorJL.loadAndFilterTargetGenes!(gdata, targpath; epsilon = 0.01)
    InferelatorJL.loadPotentialRegulators!(gdata, regpath)
    InferelatorJL.processTFAGenes!(gdata, "")

    tfa = PriorTFAData()
    InferelatorJL.processPriorFile!(tfa, gdata, priorpath;
        mergedTFsData = nothing, minTargets = 1)
    InferelatorJL.calculateTFA!(tfa, gdata; edgeSS = 0, zTarget = false)

    # Manually compute expected TFA: priorMatrix \ targExpression, and demand
    # agreement to tight tolerance.
    want = tfa.priorMatrix \ tfa.targExpression
    @test tfa.medTfas ≈ want atol=1e-10
end


# -----------------------------------------------------------------------------
@testset "TFA — z-scored targets change output" begin
# -----------------------------------------------------------------------------
    # Two TFA runs on identical inputs, differing only in zTarget; z-scoring
    # the target expression must move the estimated activities.
    workdir = mktempdir()
    gdata, geneids, tfids, sampleids = make_test_data(workdir)

    priorvals = Float64[
        1 0; 1 1; 0 1; 1 1; 1 1;
    ]
    priorpath = joinpath(workdir, "prior.tsv")
    write_prior_tsv(priorpath, tfids, sort(geneids), priorvals)

    tfaRaw = PriorTFAData()
    InferelatorJL.processPriorFile!(tfaRaw, gdata, priorpath;
        mergedTFsData = nothing, minTargets = 1)
    InferelatorJL.calculateTFA!(tfaRaw, gdata; edgeSS = 0, zTarget = false)

    tfaZ = PriorTFAData()
    InferelatorJL.processPriorFile!(tfaZ, gdata, priorpath;
        mergedTFsData = nothing, minTargets = 1)
    InferelatorJL.calculateTFA!(tfaZ, gdata; edgeSS = 0, zTarget = true)

    # z-scored and non-z-scored TFA should differ
    @test !isapprox(tfaRaw.medTfas, tfaZ.medTfas; atol = 1e-6)
end


# -----------------------------------------------------------------------------
@testset "Prior — minTargets filter removes low-coverage TFs" begin
# -----------------------------------------------------------------------------
    workdir = mktempdir()
    gdata, geneids, tfids, sampleids = make_test_data(workdir)

    # Column 1 (Foxp3) regulates four genes; column 2 (Stat3) regulates one.
    priorvals = Float64[
        1 0; # Akt1 → Foxp3 only
        1 0; # Bcl2 → Foxp3 only
        1 1; # Foxp3 → both
        1 0; # Myc → Foxp3 only
        0 1; # Stat3 → Stat3 only (1 target)
    ]
    priorpath = joinpath(workdir, "prior.tsv")
    write_prior_tsv(priorpath, tfids, sort(geneids), priorvals)

    tfa = PriorTFAData()
    # minTargets = 3: Foxp3 has 4 targets (passes), Stat3 has 1 (fails)
    InferelatorJL.processPriorFile!(tfa, gdata, priorpath;
        mergedTFsData = nothing, minTargets = 3)

    @test "Foxp3" ∈ tfa.pRegs
    @test "Stat3" ∉ tfa.pRegs
end


# -----------------------------------------------------------------------------
@testset "Penalty matrix — self-regulatory edges set to Inf in TFmRNA mode" begin
# -----------------------------------------------------------------------------
    workdir = mktempdir()
    gdata, geneids, tfids, sampleids = make_test_data(workdir)

    priorvals = Float64[
        1 0; 1 1; 0 1; 1 1; 1 1;
    ]
    priorpath = joinpath(workdir, "prior.tsv")
    write_prior_tsv(priorpath, tfids, sort(geneids), priorvals)

    tfa = PriorTFAData()
    InferelatorJL.processPriorFile!(tfa, gdata, priorpath;
        mergedTFsData = nothing, minTargets = 1)
    InferelatorJL.calculateTFA!(tfa, gdata; edgeSS = 0, zTarget = false)

    grn = GrnData()
    # TFmRNA mode: self-regulatory edges must be Inf
    InferelatorJL.preparePredictorMat!(grn, gdata, tfa; tfaOpt = "TFmRNA")
    InferelatorJL.preparePenaltyMatrix!(gdata, grn;
        priorFilePenalties = [priorpath],
        lambdaBias = [0.5],
        tfaOpt = "TFmRNA")

    # For each TF that is also a target gene, penalty[gene, TF] must be Inf
    for reg in gdata.potRegsmRNA
        gidx = findfirst(isequal(reg), gdata.targGenes)
        ridx = findfirst(isequal(reg), grn.allPredictors)
        (gidx === nothing || ridx === nothing) && continue
        @test grn.penaltyMat[gidx, ridx] == Inf
    end
end


# -----------------------------------------------------------------------------
@testset "Subsamples — shape and reproducibility" begin
# -----------------------------------------------------------------------------
    workdir = mktempdir()
    gdata, geneids, tfids, sampleids = make_test_data(workdir)

    priorvals = Float64[1 0; 1 1; 0 1; 1 1; 1 1]
    priorpath = joinpath(workdir, "prior.tsv")
    write_prior_tsv(priorpath, tfids, sort(geneids), priorvals)

    tfa = PriorTFAData()
    InferelatorJL.processPriorFile!(tfa, gdata, priorpath;
        mergedTFsData = nothing, minTargets = 1)
    InferelatorJL.calculateTFA!(tfa, gdata; edgeSS = 0, zTarget = false)

    grn = GrnData()
    InferelatorJL.preparePredictorMat!(grn, gdata, tfa; tfaOpt = "")

    nss      = 10
    frac     = 0.7
    ncells   = length(gdata.cellLabels)
    wantcols = floor(Int, frac * ncells)

    Random.seed!(123)
    InferelatorJL.constructSubsamples(gdata, grn; totSS = nss, subsampleFrac = frac)
    firstdraw = copy(grn.subsamps)

    # Correct shape
    @test size(firstdraw) == (nss, wantcols)

    # All indices are valid sample indices
    @test all(i -> 1 <= i <= ncells, firstdraw)

    # Reproducible with same seed
    Random.seed!(123)
    InferelatorJL.constructSubsamples(gdata, grn; totSS = nss, subsampleFrac = frac)
    seconddraw = copy(grn.subsamps)
    @test firstdraw == seconddraw

    # Different seed → different subsamples (with overwhelming probability)
    Random.seed!(999)
    InferelatorJL.constructSubsamples(gdata, grn; totSS = nss, subsampleFrac = frac)
    @test grn.subsamps != firstdraw
end


# -----------------------------------------------------------------------------
@testset "Edge cases — all target genes have zero variance" begin
# -----------------------------------------------------------------------------
    workdir = mktempdir()
    geneids   = ["Akt1", "Bcl2"]
    sampleids = ["S1", "S2", "S3"]
    # All rows constant — every gene will fail the variance filter
    constvals = [1.0 1.0 1.0; 2.0 2.0 2.0]
    exprpath = joinpath(workdir, "expr.txt")
    targpath = joinpath(workdir, "targs.txt")
    write_expr_tsv(exprpath, geneids, sampleids, constvals)
    write_gene_list(targpath, geneids)

    gdata = GeneExpressionData()
    InferelatorJL.loadExpressionData!(gdata, exprpath)
    @test_throws ErrorException InferelatorJL.loadAndFilterTargetGenes!(
        gdata, targpath; epsilon = 0.01)
end


# -----------------------------------------------------------------------------
@testset "Edge cases — target gene list requests genes not in expression data" begin
# -----------------------------------------------------------------------------
    workdir = mktempdir()
    geneids   = ["Akt1", "Bcl2"]
    sampleids = ["S1", "S2", "S3"]
    exprvals  = [1.0 2.0 3.0; 4.0 5.0 6.0]
    exprpath = joinpath(workdir, "expr.txt")
    targpath = joinpath(workdir, "targs.txt")
    write_expr_tsv(exprpath, geneids, sampleids, exprvals)
    # None of the requested targets exist in the expression matrix
    write_gene_list(targpath, ["NotAGene", "AlsoNotAGene"])

    gdata = GeneExpressionData()
    InferelatorJL.loadExpressionData!(gdata, exprpath)
    @test_throws ErrorException InferelatorJL.loadAndFilterTargetGenes!(
        gdata, targpath; epsilon = 0.01)
end


# =============================================================================
end # @testset "InferelatorJL"
# =============================================================================