diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 02049be..89f2156 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -2,6 +2,11 @@ on:
pull_request:
branches: [main]
+permissions:
+ contents: write
+ pull-requests: write
+ actions: write
+
jobs:
main:
runs-on: ${{ matrix.os }}
diff --git a/.isort.cfg b/.isort.cfg
index 9645ea6..130275a 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -1,2 +1,2 @@
[settings]
-known_third_party = click,networkx,pandas
+known_third_party = click,networkx,numpy,pandas
diff --git a/coverage.svg b/coverage.svg
index 3438732..12876e6 100644
--- a/coverage.svg
+++ b/coverage.svg
@@ -9,13 +9,13 @@
-
+
coverage
coverage
- 97%
- 97%
+ 78%
+ 78%
diff --git a/plasnet/alt_label_propagation.py b/plasnet/alt_label_propagation.py
new file mode 100644
index 0000000..7bc3867
--- /dev/null
+++ b/plasnet/alt_label_propagation.py
@@ -0,0 +1,96 @@
+from collections import Counter
+
+from networkx.utils import groups
+
+
def appendable_lpa_communities(G, initial_labels=None, seed=None):
    """Yield communities of `G` found by asynchronous label propagation,
    optionally starting from an existing (partial) labelling.

    The asynchronous label propagation algorithm is described in [1]_.
    The algorithm is probabilistic and the found communities may vary on
    different executions unless a fixed `seed` is given.

    Every node starts with a label: the one given in `initial_labels`
    if present, otherwise a fresh label not used by any other node.
    The algorithm then repeatedly sets the label of a node to the label
    that appears most frequently among that node's neighbours, and halts
    when every node already carries one of its most frequent neighbour
    labels. It is asynchronous because each node is updated without
    waiting for updates on the remaining nodes. Edge weights are ignored:
    every neighbour counts once.

    Parameters
    ----------
    G : Graph
        Anything that iterates over its nodes, supports ``node in G`` and
        returns the neighbours of a node through ``G[node]`` (e.g. a
        networkx graph).

    initial_labels : dict, optional
        Mapping ``node -> integer label`` used as the starting point.
        Entries for nodes that are not in `G` are ignored. Nodes of `G`
        missing from the mapping receive fresh labels strictly greater
        than every label in use. If None or empty, every node starts in
        its own community.

    seed : int, random.Random-like object, or None (default)
        Source of randomness. An object exposing ``shuffle`` and
        ``choice`` is used as is; anything else (including None and
        integers) is passed to ``random.Random``.

    Returns
    -------
    communities : iterable
        Iterable of communities given as sets of nodes.

    References
    ----------
    .. [1] Raghavan, Usha Nandini, Réka Albert, and Soundar Kumara. "Near
       linear time algorithm to detect community structures in large-scale
       networks." Physical Review E 76.3 (2007): 036106.
    """
    # Function-scope import keeps this block self-contained.
    import random

    # Callers pass plain integers (or nothing); turn those into a generator
    # object because the loop below needs `.shuffle` and `.choice`.
    if hasattr(seed, "shuffle") and hasattr(seed, "choice"):
        rng = seed
    else:
        rng = random.Random(seed)

    # Only labels of nodes that are actually in G may seed the propagation,
    # otherwise foreign nodes would show up in the returned communities.
    known = {}
    if initial_labels:
        known = {n: lab for n, lab in initial_labels.items() if n in G}

    if known:
        # Fresh labels start one past the largest known label so that an
        # unlabelled node never silently joins an existing community.
        first_free = max(known.values()) + 1
        unlabelled = (n for n in G if n not in known)
        labels = {n: first_free + i for i, n in enumerate(unlabelled)}
        labels.update(known)
    else:
        labels = {n: i for i, n in enumerate(G)}

    cont = True
    while cont:
        cont = False
        nodes = list(G)
        rng.shuffle(nodes)

        for node in nodes:
            # Isolated nodes have nothing to learn from; keep their label.
            if not G[node]:
                continue

            # Label frequencies among adjacent nodes. Depending on the
            # processing order some neighbours are already at iteration t
            # and others still at t-1, which makes the algorithm
            # asynchronous.
            label_freq = Counter(map(labels.get, G[node]))

            max_freq = max(label_freq.values())
            best_labels = [label for label, freq in label_freq.items() if freq == max_freq]

            # A node that does not carry one of the most frequent labels
            # adopts one at random; any such change forces another sweep.
            if labels[node] not in best_labels:
                labels[node] = rng.choice(best_labels)
                cont = True

    # Invert node -> label into label -> set of nodes.
    communities = {}
    for node, label in labels.items():
        communities.setdefault(label, set()).add(node)
    yield from communities.values()
diff --git a/plasnet/base_graph.py b/plasnet/base_graph.py
index e160d0a..3d134e9 100644
--- a/plasnet/base_graph.py
+++ b/plasnet/base_graph.py
@@ -143,3 +143,8 @@ def load(cls: Type[BaseGraphType], filepath: Path) -> BaseGraphType:
def write_classification(self, typing_fh: TextIO) -> None:
for node in self.nodes:
typing_fh.write(f"{node}\t{self.label}\n")
+
def compare_classification(self, prev_typing: dict, typing_fh: TextIO) -> None:
    """Write one ``node<TAB>label<TAB>previous label`` line per previously typed node.

    Nodes of this graph that have no entry in `prev_typing` are skipped, so
    the output only contains nodes for which an old and a new label exist.

    Args:
        prev_typing: mapping from node to the label it had previously.
        typing_fh: writable text handle that receives the lines.
    """
    for node in self.nodes:
        # Membership is tested on the mapping itself; `.keys()` adds nothing.
        if node in prev_typing:
            typing_fh.write(f"{node}\t{self.label}\t{prev_typing[node]}\n")
diff --git a/plasnet/clustering_dists.py b/plasnet/clustering_dists.py
new file mode 100644
index 0000000..5de5d23
--- /dev/null
+++ b/plasnet/clustering_dists.py
@@ -0,0 +1,133 @@
+import numpy as np
+import pandas as pd
+
+
def read_in_clusters(compare_tsv):
    """Load a typing comparison table and return both clusterings it contains.

    The table is tab separated and must provide the columns ``plasmid``,
    ``type`` and ``previous_type``.

    Parameters
    ----------
    compare_tsv : path or file-like object
        Anything accepted by ``pandas.read_csv``.

    Returns
    -------
    clusters_pling : dict[int, set]
        Plasmids grouped by ``type``, keyed ``0..k-1``.
    clusters_pling_old : dict[int, set]
        Plasmids grouped by ``previous_type``, keyed ``0..k-1``.
    plasmids : list
        The ``plasmid`` column in file order.

    Notes
    -----
    Clusters are numbered in order of first appearance in the file, so the
    numbering is reproducible between runs. Rows whose grouping value is
    missing (NaN) do not form a cluster.
    """
    pling_df = pd.read_csv(compare_tsv, sep="\t")
    plasmids = pling_df["plasmid"].tolist()

    def clusters_by(column):
        # A single groupby pass replaces re-filtering the frame per cluster.
        grouped = pling_df.groupby(column, sort=False)["plasmid"]
        return {i: set(members) for i, (_, members) in enumerate(grouped)}

    return clusters_by("type"), clusters_by("previous_type"), plasmids
+
+
def make_contingency_matrix(clusters_1, clusters_2):
    """Build the contingency table of two clusterings.

    Parameters
    ----------
    clusters_1, clusters_2 : dict
        Each maps a cluster key to the collection of members of that
        cluster. Keys may be arbitrary; rows and columns follow the
        iteration order of the respective dictionary.

    Returns
    -------
    contingency : numpy.ndarray of float, shape (k_1, k_2)
        Entry ``[i, j]`` is the number of members shared by the i-th
        cluster of `clusters_1` and the j-th cluster of `clusters_2`.
    k_1, k_2 : int
        Number of clusters in `clusters_1` and `clusters_2`.
    """
    # Iterate over the values instead of indexing with 0..k-1 so that the
    # dictionaries do not have to be keyed by consecutive integers.
    rows = [set(members) for members in clusters_1.values()]
    cols = [set(members) for members in clusters_2.values()]
    k_1 = len(rows)
    k_2 = len(cols)

    contingency = np.zeros((k_1, k_2))
    for i, row_members in enumerate(rows):
        for j, col_members in enumerate(cols):
            contingency[i, j] = len(row_members & col_members)
    return contingency, k_1, k_2
+
+
def split_join(contingency, k_1, k_2, n):
    """Return the split-join distance computed from a contingency table.

    The distance is ``2 * n`` minus the sum of the row maxima of the first
    `k_1` rows minus the sum of the column maxima of the first `k_2`
    columns.

    Parameters
    ----------
    contingency : array-like, shape (rows, columns)
        Contingency table of the two clusterings.
    k_1, k_2 : int
        Number of rows and columns to take into account (normally the full
        shape of `contingency`).
    n : int
        Total number of data points (plasmids).

    Returns
    -------
    int
        The split-join distance; ``2 * n`` for an empty table.
    """
    # Accept nested lists as well, like rand_index does.
    contingency = np.asarray(contingency)
    if contingency.size == 0:
        # No overlaps to subtract; also avoids max() over an empty axis.
        return int(2 * n)

    best_per_row = contingency[:k_1].max(axis=1).sum()
    best_per_col = contingency[:, :k_2].max(axis=0).sum()
    return int(2 * n - best_per_row - best_per_col)
+
+
+def rand_index(contingency):
+ contingency = np.asarray(contingency)
+
+ def comb2(x):
+ return x * (x - 1) / 2.0
+
+ n = contingency.sum()
+ if n <= 1:
+ return 1.0 # degenerate case
+
+ # True positives
+ tp = np.sum(comb2(contingency))
+
+ # Row and column sums
+ row_sums = contingency.sum(axis=1)
+ col_sums = contingency.sum(axis=0)
+
+ sum_rows = np.sum(comb2(row_sums))
+ sum_cols = np.sum(comb2(col_sums))
+
+ fp = sum_cols - tp
+ fn = sum_rows - tp
+
+ total_pairs = comb2(n)
+ tn = total_pairs - tp - fp - fn
+
+ ri = (tp + tn) / total_pairs
+ return ri
+
+
def adjusted_rand_index(contingency):
    """Return the adjusted Rand index (ARI) for a contingency table.

    Parameters
    ----------
    contingency : array-like, shape (k_1, k_2)
        Contingency table of the two clusterings.

    Returns
    -------
    float
        The ARI. 1.0 is returned for the degenerate inputs where the
        formula is undefined: a table with at most one data point, or a
        zero denominator. The denominator only vanishes when both
        clusterings put everything in one cluster or both consist of
        singletons, i.e. when the two partitions are identical.
    """
    # Accept nested lists as well, like rand_index does.
    contingency = np.asarray(contingency, dtype=float)

    # Helper function: n choose 2
    def comb2(x):
        return x * (x - 1) / 2.0

    n = contingency.sum()
    if n <= 1:
        # No pair exists, so the partitions cannot disagree.
        return 1.0

    # Pairs placed together by both clusterings.
    sum_comb_cells = np.sum(comb2(contingency))

    # Pairs placed together by each clustering on its own.
    sum_comb_rows = np.sum(comb2(contingency.sum(axis=1)))
    sum_comb_cols = np.sum(comb2(contingency.sum(axis=0)))

    total_pairs = comb2(n)

    expected_index = (sum_comb_rows * sum_comb_cols) / total_pairs
    max_index = 0.5 * (sum_comb_rows + sum_comb_cols)

    denominator = max_index - expected_index
    if denominator == 0:
        # Identical trivial partitions: perfect agreement.
        return 1.0

    return (sum_comb_cells - expected_index) / denominator
+
+
def mutual_information(contingency):
    """Return the mutual information (natural log) for a contingency table.

    Parameters
    ----------
    contingency : array-like, shape (k_1, k_2)
        Contingency table of the two clusterings.

    Returns
    -------
    float
        Sum over all non-zero cells of
        ``(n_ij / n) * log(n_ij * n / (a_i * b_j))`` where ``a_i`` and
        ``b_j`` are the row and column sums; 0.0 for an all-zero table.
    """
    # Accept nested lists as well, like rand_index does.
    contingency = np.asarray(contingency, dtype=float)

    n = contingency.sum()
    if n == 0:
        return 0.0

    row_sums = contingency.sum(axis=1)
    col_sums = contingency.sum(axis=0)

    # Zero cells contribute nothing and would put a zero inside the log.
    i_idx, j_idx = np.nonzero(contingency > 0)
    nij = contingency[i_idx, j_idx]

    # Row and column totals matching each non-zero cell.
    ai = row_sums[i_idx]
    bj = col_sums[j_idx]

    return np.sum((nij / n) * np.log((nij * n) / (ai * bj)))
+
+
def all_clustering_dists(contingency, k_1, k_2, n):
    """Compute every clustering comparison measure for one contingency table.

    Returns a dictionary keyed by the measure's display name ("rand index",
    "adjusted rand index", "mutual information", "split join"), in that
    order. `k_1`, `k_2` and `n` are only forwarded to ``split_join``.
    """
    return {
        "rand index": rand_index(contingency),
        "adjusted rand index": adjusted_rand_index(contingency),
        "mutual information": mutual_information(contingency),
        "split join": split_join(contingency, k_1, k_2, n),
    }
diff --git a/plasnet/community_graph.py b/plasnet/community_graph.py
index dc7729d..c4aeba0 100644
--- a/plasnet/community_graph.py
+++ b/plasnet/community_graph.py
@@ -4,6 +4,7 @@
import networkx as nx
+from plasnet.alt_label_propagation import appendable_lpa_communities
from plasnet.ColorPicker import ColorPicker
from plasnet.hub_graph import HubGraph
from plasnet.subcommunities import Subcommunities
@@ -97,6 +98,116 @@ def split_graph_into_subcommunities(
return Subcommunities(subcommunities)
def split_graph_given_labels(
    self, small_subcommunity_size_threshold: int, typings: list[dict]
) -> Subcommunities:
    """Split this community into subcommunities, starting from previous typings.

    Nodes that appear in one of `typings` start label propagation with a
    label derived from their previous subcommunity. The remaining (new)
    nodes are first clustered among themselves with
    ``nx.community.asyn_lpa_communities`` and each resulting group gets a
    fresh label. ``appendable_lpa_communities`` then runs on the whole
    graph from this labelling, and the result is passed through
    ``self._fix_small_subcommunities`` with the given threshold.

    Args:
        small_subcommunity_size_threshold: forwarded unchanged to
            ``self._fix_small_subcommunities``.
        typings: previous typings, each a mapping from plasmid to the name
            of its previous subcommunity. Equal names in different typings
            are treated as different subcommunities. If a plasmid occurs in
            several typings, the last typing wins.

    Returns:
        The subcommunities, labelled
        ``{self.label}_subcommunity_{index}``. As a side effect the colour
        of every node is recorded in ``self._node_to_colour``.
    """
    # Function-scope import keeps this block self-contained.
    import random

    next_label = 0
    initial_labels: dict = {}
    for typing in typings:
        # The name -> label table is kept per typing so that identical
        # names coming from different typings never share a label.
        label_of_subcommunity: dict = {}
        for plasmid, previous_subcommunity in typing.items():
            if plasmid not in self.nodes:
                continue
            if previous_subcommunity not in label_of_subcommunity:
                label_of_subcommunity[previous_subcommunity] = next_label
                next_label += 1
            initial_labels[plasmid] = label_of_subcommunity[previous_subcommunity]

    # Everything without a previous type is new; cluster it on its own first.
    new_plasmids = [plasmid for plasmid in self.nodes if plasmid not in initial_labels]
    new_subcommunities_nodes: list[set[str]] = list(
        nx.community.asyn_lpa_communities(G=self.subgraph(new_plasmids), seed=42)
    )
    for new_subcommunity in new_subcommunities_nodes:
        for plasmid in new_subcommunity:
            initial_labels[plasmid] = next_label
        next_label += 1

    # A generator object is passed rather than a bare integer because the
    # propagation calls `.shuffle` / `.choice` on its seed argument.
    subcommunities_nodes: list[set[str]] = list(
        appendable_lpa_communities(
            G=self, initial_labels=initial_labels, seed=random.Random(42)
        )
    )
    subcommunities_nodes = self._fix_small_subcommunities(
        subcommunities_nodes, small_subcommunity_size_threshold
    )

    subcommunities = []
    for subcommunity_index, subcommunity_nodes in enumerate(subcommunities_nodes):
        colour = ColorPicker.get_color_given_index(subcommunity_index)

        subcommunity = SubcommunityGraph(
            self.subgraph(subcommunity_nodes),
            self._hub_connectivity_threshold,
            self._edge_density,
            label=f"{self.label}_subcommunity_{subcommunity_index}",
            colour=colour,
        )
        subcommunities.append(subcommunity)

        for node in subcommunity_nodes:
            self._node_to_colour[node] = colour

    return Subcommunities(subcommunities)
+
def nearest_neighbour(self, typing, new_plasmids) -> Subcommunities:
    """Assign new plasmids to the previous subcommunity of their nearest neighbour.

    No clustering is performed. Previously typed plasmids keep their
    subcommunity. Each new plasmid present in this graph joins the
    subcommunity of the previously typed neighbour with the smallest split
    distance; ties are broken in favour of the neighbour whose previous
    subcommunity was largest. A new plasmid without a usable neighbour
    starts a new subcommunity of its own.

    Args:
        typing: table with the columns ``plasmid`` and ``type`` (a pandas
            DataFrame in the caller), where ``type`` looks like
            ``<community label>_subcommunity_<index>``. Only types whose
            second ``_``-separated field equals that of ``self.label``
            are considered part of this community.
        new_plasmids: plasmids that have no previous type.

    Returns:
        One subcommunity per label, reusing the previous labels. As a side
        effect the colour of every listed plasmid is recorded in
        ``self._node_to_colour``.
    """
    distance_tag = DistanceTags.SplitDistanceTag.value
    new_plasmids = list(new_plasmids)
    new_plasmid_set = set(new_plasmids)  # O(1) membership tests below
    community_id = self.label.split("_")[1]

    # Read the table once instead of re-filtering it for every lookup.
    previous_type: dict = {}
    subcommunity_labels: dict = {}
    for plasmid, subcommunity in zip(typing["plasmid"], typing["type"]):
        previous_type.setdefault(plasmid, subcommunity)
        # select only those that are in this community
        if subcommunity.split("_")[1] == community_id:
            subcommunity_labels.setdefault(subcommunity, []).append(plasmid)
    previous_type_size = typing["type"].value_counts().to_dict()

    # Continue numbering after the highest index in use; counting the
    # labels would collide with an existing label if the indices have gaps.
    next_index = 1 + max(
        (int(label.split("_")[-1]) for label in subcommunity_labels), default=-1
    )

    for plasmid in new_plasmids:
        if plasmid not in self.nodes:
            continue

        # Usable neighbours: previously typed into a subcommunity of this
        # community. Untyped neighbours cannot donate a label.
        candidates = [
            neighbour
            for neighbour in self[plasmid]
            if neighbour not in new_plasmid_set
            and previous_type.get(neighbour) in subcommunity_labels
        ]
        if not candidates:
            # Same naming scheme as the regular typing:
            # "<community label>_subcommunity_<index>".
            subcommunity_labels[f"{self.label}_subcommunity_{next_index}"] = [plasmid]
            next_index += 1
            continue

        distance_to = {
            neighbour: self.edges[neighbour, plasmid][distance_tag]
            for neighbour in candidates
        }
        min_dist = min(distance_to.values())
        nearest = [neighbour for neighbour in candidates if distance_to[neighbour] == min_dist]
        # select nearest neighbour with largest subcommunity size; walking
        # the ties in reverse makes the last of equally large ones win.
        nn = max(
            reversed(nearest),
            key=lambda neighbour: previous_type_size[previous_type[neighbour]],
        )
        subcommunity_labels[previous_type[nn]].append(plasmid)

    subcommunities = []
    for subcommunity_label, members in subcommunity_labels.items():
        subcommunity_index = int(subcommunity_label.split("_")[-1])
        colour = ColorPicker.get_color_given_index(subcommunity_index)

        subcommunity = SubcommunityGraph(
            self.subgraph(members),
            self._hub_connectivity_threshold,
            self._edge_density,
            label=subcommunity_label,  # reuse old labels here!
            colour=colour,
        )
        subcommunities.append(subcommunity)

        for node in members:
            self._node_to_colour[node] = colour

    return Subcommunities(subcommunities)
+
def _get_libs_relative_path(self) -> str:
return ".."
diff --git a/plasnet/list_of_graphs.py b/plasnet/list_of_graphs.py
index c2c4f6c..d7c2660 100644
--- a/plasnet/list_of_graphs.py
+++ b/plasnet/list_of_graphs.py
@@ -1,6 +1,6 @@
import pickle
from pathlib import Path
-from typing import Generator, cast
+from typing import Generator, Optional, cast
from plasnet.base_graph import BaseGraphType
@@ -27,11 +27,16 @@ def save_graph_as_text(self, filepath: Path) -> None:
for graph_as_text in self._get_each_graph_as_list_of_nodes_in_text_format():
print(graph_as_text, file=fh)
def save_classification(
    self, filepath: Path, header: str, prev_typing: Optional[dict] = None
) -> None:
    """Write the classification of every graph in this list to a text file.

    The file starts with `header`. Without `prev_typing` each graph writes
    its rows through ``write_classification``. With `prev_typing` (even an
    empty mapping) each graph writes through ``compare_classification``,
    so the rows always match the comparison header supplied by the caller.

    Args:
        filepath: file to create or overwrite.
        header: first line of the file.
        prev_typing: optional mapping from node to its previous label.
    """
    # `is not None` rather than truthiness: an empty previous typing must
    # still select the comparison layout.
    compare = prev_typing is not None
    with open(filepath, "w") as fh:
        print(header, file=fh)
        for subgraph in self:
            if compare:
                subgraph.compare_classification(prev_typing, fh)
            else:
                subgraph.write_classification(fh)
def get_graphs_sorted_by_size(self) -> "ListOfGraphs[BaseGraphType]":
return ListOfGraphs(sorted(self, key=lambda graph: graph.number_of_nodes(), reverse=True))
diff --git a/plasnet/plasmid_graph.py b/plasnet/plasmid_graph.py
index fc38445..5cdad3a 100644
--- a/plasnet/plasmid_graph.py
+++ b/plasnet/plasmid_graph.py
@@ -26,6 +26,7 @@ def build(
distance_filepath: Path,
distance_threshold: float,
plasmids_metadata: list[str],
+ existing_graphs: Optional[tuple] = None,
) -> "PlasmidGraph":
"""
Creates a plasmid graph from plasmid and distance files.
@@ -62,7 +63,9 @@ def build(
""" # noqa: E501
plasmids = pd.read_csv(plasmids_filepath)
- distance_df = pd.read_csv(distance_filepath, sep="\t")
+ distance_df = pd.read_csv(
+ distance_filepath, dtype={"plasmid_1": str, "plasmid_2": str}, sep="\t"
+ )
distance_df[DistanceTags.SplitDistanceTag.value] = distance_df["distance"]
# apply distance threshold
@@ -84,6 +87,10 @@ def build(
create_using=PlasmidGraph,
)
+ if existing_graphs:
+ for existing_graph in existing_graphs:
+ graph = nx.compose(graph, existing_graph)
+
# add all nodes to the graph, including those that have no edges
# possibly add metadata if they were provided
plasmid_metadata_is_too_short = len(plasmids_metadata) < len(plasmids["plasmid"])
diff --git a/plasnet/plasnet_main.py b/plasnet/plasnet_main.py
index e464a43..3fade9a 100644
--- a/plasnet/plasnet_main.py
+++ b/plasnet/plasnet_main.py
@@ -7,6 +7,7 @@
import pandas as pd
from plasnet import __version__
+from plasnet.clustering_dists import all_clustering_dists, make_contingency_matrix, read_in_clusters
from plasnet.communities import Communities
from plasnet.output_producer import OutputProducer
from plasnet.plasmid_graph import PlasmidGraph
@@ -97,6 +98,15 @@ def cli() -> None:
@click.option(
"--plasmids-metadata", type=PathlibPath(exists=True), help="Plasmids metadata text file."
)
+@click.option(
+ "--graph-pickle", multiple=True, help="Existing plasmid graph to append new plasmids to."
+)
+@click.option(
+ "--prev_typing",
+ multiple=True,
+ help="Previous community typing, if appending to an existing plasmid graph.",
+)
+@click.option("--no-community-vis", is_flag=True)
def split(
plasmids: Path,
distances: Path,
@@ -107,13 +117,26 @@ def split(
output_plasmid_graph: bool,
output_type: Optional[str],
plasmids_metadata: Optional[Path],
+ graph_pickle: Optional[tuple],
+ prev_typing: Optional[tuple],
+ no_community_vis: bool,
) -> None:
visualisations_dir = output_dir / "visualisations"
logging.info(f"Creating plasmid graph from {plasmids} and {distances}")
metadata = []
+
if plasmids_metadata:
metadata = plasmids_metadata.read_text().splitlines()
- plasmid_graph = PlasmidGraph.build(plasmids, distances, distance_threshold, metadata)
+ if graph_pickle:
+ existing_graphs = [cast(PlasmidGraph, PlasmidGraph.load(graph)) for graph in graph_pickle]
+ plasmid_graph = PlasmidGraph.build(
+ plasmids, distances, distance_threshold, metadata, existing_graphs
+ )
+ typings = [
+ pd.read_csv(prev, sep="\t", index_col=0).to_dict()["community"] for prev in prev_typing
+ ]
+ else:
+ plasmid_graph = PlasmidGraph.build(plasmids, distances, distance_threshold, metadata)
if output_plasmid_graph:
logging.info("Producing full plasmid graph visualisation")
@@ -126,10 +149,11 @@ def split(
bh_connectivity, bh_neighbours_edge_density
)
- logging.info("Producing communities visualisation")
- OutputProducer.produce_communities_visualisation(
- communities, visualisations_dir / "communities", output_type
- )
+ if not no_community_vis:
+ logging.info("Producing communities visualisation")
+ OutputProducer.produce_communities_visualisation(
+ communities, visualisations_dir / "communities", output_type
+ )
logging.info("Serialising objects")
objects_dir = output_dir / "objects"
@@ -138,6 +162,13 @@ def split(
communities.save(objects_dir / "communities.pkl")
communities.save_graph_as_text(objects_dir / "communities.txt")
communities.save_classification(objects_dir / "communities.tsv", "plasmid\tcommunity")
+ if prev_typing:
+ for i, typing in enumerate(typings):
+ communities.save_classification(
+ objects_dir / f"compare_communities_{i}.tsv",
+ "plasmid\tcommunity\tprevious_community",
+ prev_typing=typing,
+ )
logging.info("All done!")
@@ -191,6 +222,19 @@ def split(
default="html",
help="Whether to output networks as html visualisations, cytoscape formatted json, or both.",
)
+@click.option("--prev_typing", multiple=True, help="Previous subcommunity typing, if it exists.")
+@click.option(
+ "--reclustering_method",
+ type=click.Choice(["unbiased", "biased", "nearest_neighbour"]),
+ default="unbiased",
+ help="unbiased: If including a previous subcommunity typing, all previous and new genomes "
+ "will be reclustered from scratch, ignoring previous typing.\n"
+ "biased: The asynchronous label propagation will start with the previous typing as initial "
+ "labels.\n"
+ "nearest_neighbour: Does not cluster the new genomes, rather, assigns type based on the "
+ "closest neighbour of the previous typing.",
+)
+@click.option("--no-vis", is_flag=True)
def type(
communities_pickle: Path,
distances: Path,
@@ -198,6 +242,9 @@ def type(
distance_threshold: float,
small_subcommunity_size_threshold: int,
output_type: Optional[str],
+ prev_typing: Optional[tuple],
+ reclustering_method: Optional[str],
+ no_vis: bool,
) -> None:
logging.info(f"Loading communities from {communities_pickle}")
communities = cast(Communities, Communities.load(communities_pickle))
@@ -218,26 +265,47 @@ def type(
communities.filter_by_distance(distance_threshold)
logging.info("Typing communities (i.e. splitting them into subcommunities)")
+
+ if prev_typing and reclustering_method == "nearest_neighbour":
+ typing = pd.read_csv(
+ prev_typing[0], sep="\t"
+ ) # nearest neighbour does not support merging graphs
+ elif prev_typing:
+ typings = [
+ pd.read_csv(prev, sep="\t", index_col=0).to_dict()["type"] for prev in prev_typing
+ ]
+
all_subcommunities = Subcommunities()
all_hub_plasmids = set()
for community in communities:
hub_plasmids = community.remove_hub_plasmids()
all_hub_plasmids.update(hub_plasmids)
- subcommunities = community.split_graph_into_subcommunities(
- small_subcommunity_size_threshold
- )
+ if prev_typing and reclustering_method == "biased":
+ subcommunities = community.split_graph_given_labels(
+ small_subcommunity_size_threshold, typings
+ )
+ elif prev_typing and reclustering_method == "nearest_neighbour":
+ new_plasmids = [
+ plasmid for plasmid in community.nodes if plasmid not in typing["plasmid"].to_list()
+ ]
+ subcommunities = community.nearest_neighbour(typing, new_plasmids)
+ else:
+ subcommunities = community.split_graph_into_subcommunities(
+ small_subcommunity_size_threshold
+ )
all_subcommunities.extend(subcommunities)
- logging.info("Producing communities visualisations")
- original_communities.recolour_nodes(communities)
- OutputProducer.produce_communities_visualisation(
- original_communities, output_dir / "visualisations/communities", output_type
- )
+ if not no_vis:
+ logging.info("Producing communities visualisations")
+ original_communities.recolour_nodes(communities)
+ OutputProducer.produce_communities_visualisation(
+ original_communities, output_dir / "visualisations/communities", output_type
+ )
- logging.info("Producing subcommunities visualisations")
- OutputProducer.produce_subcommunities_visualisation(
- all_subcommunities, output_dir / "visualisations/subcommunities", output_type
- )
+ logging.info("Producing subcommunities visualisations")
+ OutputProducer.produce_subcommunities_visualisation(
+ all_subcommunities, output_dir / "visualisations/subcommunities", output_type
+ )
logging.info("Serialising objects")
objects_dir = output_dir / "objects"
@@ -250,6 +318,25 @@ def type(
for plasmid in all_hub_plasmids:
print(plasmid, file=hub_plasmids_fh)
+ if prev_typing and reclustering_method != "nearest_neighbour":
+ for i, typing in enumerate(typings):
+ all_subcommunities.save_classification(
+ objects_dir / f"compare_typing_{i}.tsv",
+ "plasmid\ttype\tprevious_type",
+ prev_typing=typing,
+ )
+
+ clusters_pling, clusters_pling_old, plasmids = read_in_clusters(
+ objects_dir / f"compare_typing_{i}.tsv"
+ )
+ n = len(plasmids)
+ contingency, k_1, k_2 = make_contingency_matrix(clusters_pling, clusters_pling_old)
+ clust_dists = all_clustering_dists(contingency, k_1, k_2, n)
+ with open(objects_dir / f"clustering_dists_{i}.tsv", "w") as f:
+ f.write("distance_type\tdistance\n")
+ for key in clust_dists.keys():
+ f.write(f"{key}\t{clust_dists[key]}\n")
+
logging.info("All done!")