Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ on:
pull_request:
branches: [main]

permissions:
contents: write
pull-requests: write
actions: write

jobs:
main:
runs-on: ${{ matrix.os }}
Expand Down
2 changes: 1 addition & 1 deletion .isort.cfg
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[settings]
known_third_party = click,networkx,pandas
known_third_party = click,networkx,numpy,pandas
6 changes: 3 additions & 3 deletions coverage.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
96 changes: 96 additions & 0 deletions plasnet/alt_label_propagation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from collections import Counter

from networkx.utils import groups


def appendable_lpa_communities(G, initial_labels=None, seed=None):
"""Returns communities in `G` as detected by asynchronous label
propagation.

The asynchronous label propagation algorithm is described in
[1]_. The algorithm is probabilistic and the found communities may
vary on different executions.

The algorithm proceeds as follows. After initializing each node with
a unique label, the algorithm repeatedly sets the label of a node to
be the label that appears most frequently among that nodes
neighbors. The algorithm halts when each node has the label that
appears most frequently among its neighbors. The algorithm is
asynchronous because each node is updated without waiting for
updates on the remaining nodes.

This generalized version of the algorithm in [1]_ accepts edge
weights.

Parameters
----------
G : Graph

weight : string
The edge attribute representing the weight of an edge.
If None, each edge is assumed to have weight one. In this
algorithm, the weight of an edge is used in determining the
frequency with which a label appears among the neighbors of a
node: a higher weight means the label appears more often.

seed : integer, random_state, or None (default)
Indicator of random number generation state.
See :ref:`Randomness<randomness>`.

Returns
-------
communities : iterable
Iterable of communities given as sets of nodes.

Notes
-----
Edge weight attributes must be numerical.

References
----------
.. [1] Raghavan, Usha Nandini, Réka Albert, and Soundar Kumara. "Near
linear time algorithm to detect community structures in large-scale
networks." Physical Review E 76.3 (2007): 036106.
"""

if not initial_labels:
labels = {n: i for i, n in enumerate(G)}
else:
start = max(initial_labels.values())
H = G.copy()
H.remove_nodes_from(initial_labels.keys())
labels = {n: i + start for i, n in enumerate(H)}
labels.update(initial_labels)

cont = True

while cont:
cont = False
nodes = list(G)
seed.shuffle(nodes)

for node in nodes:
if not G[node]:
continue

# Get label frequencies among adjacent nodes.
# Depending on the order they are processed in,
# some nodes will be in iteration t and others in t-1,
# making the algorithm asynchronous.
# initialising a Counter from an iterator of labels is
# faster for getting unweighted label frequencies
label_freq = Counter(map(labels.get, G[node]))

# Get the labels that appear with maximum frequency.
max_freq = max(label_freq.values())
best_labels = [label for label, freq in label_freq.items() if freq == max_freq]

# If the node does not have one of the maximum frequency labels,
# randomly choose one of them and update the node's label.
# Continue the iteration as long as at least one node
# doesn't have a maximum frequency label.
if labels[node] not in best_labels:
labels[node] = seed.choice(best_labels)
cont = True

yield from groups(labels).values()
5 changes: 5 additions & 0 deletions plasnet/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,8 @@ def load(cls: Type[BaseGraphType], filepath: Path) -> BaseGraphType:
def write_classification(self, typing_fh: TextIO) -> None:
for node in self.nodes:
typing_fh.write(f"{node}\t{self.label}\n")

def compare_classification(self, prev_typing: dict, typing_fh: TextIO) -> None:
for node in self.nodes:
if node in prev_typing.keys():
typing_fh.write(f"{node}\t{self.label}\t{prev_typing[node]}\n")
133 changes: 133 additions & 0 deletions plasnet/clustering_dists.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import numpy as np
import pandas as pd


def read_in_clusters(compare_tsv):
pling_df = pd.read_csv(compare_tsv, sep="\t")
plasmids = list(pling_df["plasmid"].values)
clusters_pling = {
i: set(pling_df[pling_df["type"] == el]["plasmid"].values)
for i, el in enumerate(list(set(pling_df["type"])))
}
clusters_pling_old = {
i: set(pling_df[pling_df["previous_type"] == el]["plasmid"].values)
for i, el in enumerate(list(set(pling_df["previous_type"])))
}
return clusters_pling, clusters_pling_old, plasmids


def make_contingency_matrix(
clusters_1, clusters_2
): # clusters_1 and clusters_2 are dictionaries of clusters,
# k_1 and k_2 the lengths of the respective dictionaries
k_1 = len(clusters_1)
k_2 = len(clusters_2)
contingency = np.zeros((k_1, k_2))
for i in range(k_1):
for j in range(k_2):
contingency[i][j] = len(clusters_1[i].intersection(clusters_2[j]))
return contingency, k_1, k_2


def split_join(contingency, k_1, k_2, n): # clusters_1 and clusters_2 are dictionaries of clusters,
# n is the total number of data points (plasmids)
dist = (
2 * n
- sum([max(contingency[i]) for i in range(k_1)])
- sum([max(contingency[:, j]) for j in range(k_2)])
)
return int(dist)


def rand_index(contingency):
contingency = np.asarray(contingency)

def comb2(x):
return x * (x - 1) / 2.0

n = contingency.sum()
if n <= 1:
return 1.0 # degenerate case

# True positives
tp = np.sum(comb2(contingency))

# Row and column sums
row_sums = contingency.sum(axis=1)
col_sums = contingency.sum(axis=0)

sum_rows = np.sum(comb2(row_sums))
sum_cols = np.sum(comb2(col_sums))

fp = sum_cols - tp
fn = sum_rows - tp

total_pairs = comb2(n)
tn = total_pairs - tp - fp - fn

ri = (tp + tn) / total_pairs
return ri


def adjusted_rand_index(contingency):
# Helper function: n choose 2
def comb2(x):
return x * (x - 1) / 2.0

n = contingency.sum()
if n <= 1:
return 0.0

# Sum over all pairs in cells
sum_comb_cells = np.sum(comb2(contingency))

# Row and column sums
row_sums = contingency.sum(axis=1)
col_sums = contingency.sum(axis=0)

sum_comb_rows = np.sum(comb2(row_sums))
sum_comb_cols = np.sum(comb2(col_sums))

total_pairs = comb2(n)

expected_index = (sum_comb_rows * sum_comb_cols) / total_pairs
max_index = 0.5 * (sum_comb_rows + sum_comb_cols)

denominator = max_index - expected_index
if denominator == 0:
return 0.0

ari = (sum_comb_cells - expected_index) / denominator
return ari


def mutual_information(contingency):
n = contingency.sum()
if n == 0:
return 0.0

row_sums = contingency.sum(axis=1)
col_sums = contingency.sum(axis=0)

# Only consider nonzero entries
nz = contingency > 0
nij = contingency[nz]

# Corresponding row and column sums
i_idx, j_idx = np.nonzero(nz)
ai = row_sums[i_idx]
bj = col_sums[j_idx]

# Compute MI
mi = np.sum((nij / n) * np.log((nij * n) / (ai * bj)))

return mi


def all_clustering_dists(contingency, k_1, k_2, n):
dists = {}
dists["rand index"] = rand_index(contingency)
dists["adjusted rand index"] = adjusted_rand_index(contingency)
dists["mutual information"] = mutual_information(contingency)
dists["split join"] = split_join(contingency, k_1, k_2, n)
return dists
111 changes: 111 additions & 0 deletions plasnet/community_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import networkx as nx

from plasnet.alt_label_propagation import appendable_lpa_communities
from plasnet.ColorPicker import ColorPicker
from plasnet.hub_graph import HubGraph
from plasnet.subcommunities import Subcommunities
Expand Down Expand Up @@ -97,6 +98,116 @@ def split_graph_into_subcommunities(

return Subcommunities(subcommunities)

def split_graph_given_labels(
self, small_subcommunity_size_threshold: int, typings: list[dict]
) -> Subcommunities:
old_plasmids = [
plasmid for typing in typings for plasmid in typing.keys() if plasmid in self.nodes
]
new_plasmids = [plasmid for plasmid in self.nodes if plasmid not in old_plasmids]
new_subcommunities_nodes: list[set[str]] = list(
nx.community.asyn_lpa_communities(G=self.subgraph(new_plasmids), seed=42)
)

label = 0
map = {}
for typing in typings:
for i, subcomm in enumerate(typing.values()):
map[subcomm] = label + i
label = label + len(typing.keys())
initial_labels = {n: map[typing[n]] for n in old_plasmids}
for subcomm in new_subcommunities_nodes:
for plasmid in list(subcomm):
initial_labels[plasmid] = label
label = label + 1
subcommunities_nodes: list[set[str]] = list(
appendable_lpa_communities(G=self, initial_labels=initial_labels, seed=42)
)
subcommunities_nodes = self._fix_small_subcommunities(
subcommunities_nodes, small_subcommunity_size_threshold
)

subcommunities = []
for subcommunity_index, subcommunity_nodes in enumerate(subcommunities_nodes):
colour = ColorPicker.get_color_given_index(subcommunity_index)

subcommunity = SubcommunityGraph(
self.subgraph(subcommunity_nodes),
self._hub_connectivity_threshold,
self._edge_density,
label=f"{self.label}_subcommunity_{subcommunity_index}",
colour=colour,
)
subcommunities.append(subcommunity)

for node in subcommunity_nodes:
self._node_to_colour[node] = colour

return Subcommunities(subcommunities)

def nearest_neighbour(self, typing, new_plasmids) -> Subcommunities:
subcommunity_names = set(typing["type"].to_list())
subcommunity_labels = {
subcomm: [plasmid for plasmid in typing[typing["type"] == subcomm]["plasmid"].values]
for subcomm in list(subcommunity_names)
if subcomm.split("_")[1] == self.label.split("_")[1]
} # select only those that are in this community
max_label = len(subcommunity_labels.keys())

for plasmid in new_plasmids:
if plasmid in self.nodes:
neighbours = [n for n in self[plasmid] if n not in new_plasmids]
if len(neighbours) == 0:
subcommunity_labels[f"community_{self.label}_subcommunity_{max_label}"] = [
plasmid
]
max_label = max_label + 1
else:
neighbours = sorted(
neighbours,
key=lambda n: self.edges[n, plasmid][DistanceTags.SplitDistanceTag.value],
)
min_dist = self.edges[neighbours[0], plasmid][
DistanceTags.SplitDistanceTag.value
]
nearest = [
neighbour
for neighbour in neighbours
if self.edges[neighbour, plasmid][DistanceTags.SplitDistanceTag.value]
== min_dist
]
nearest = sorted(
nearest,
key=lambda n: len(
typing[
typing["type"] == typing[typing["plasmid"] == n]["type"].values[0]
]
),
)
nn = nearest[-1] # select nearest neighbour with largest subcommunity size
subcommunity_labels[typing[typing["plasmid"] == nn]["type"].values[0]].append(
plasmid
)

subcommunities = []
for subcommunity_label in subcommunity_labels.keys():
subcommunity_index = int(subcommunity_label.split("_")[-1])
colour = ColorPicker.get_color_given_index(subcommunity_index)

subcommunity = SubcommunityGraph(
self.subgraph(subcommunity_labels[subcommunity_label]),
self._hub_connectivity_threshold,
self._edge_density,
label=subcommunity_label, # reuse old labels here!
colour=colour,
)
subcommunities.append(subcommunity)

for node in subcommunity_labels[subcommunity_label]:
self._node_to_colour[node] = colour

return Subcommunities(subcommunities)

def _get_libs_relative_path(self) -> str:
return ".."

Expand Down
Loading