Wasserstein Network (wnet) is a Python/C++ library for computing and manipulating Wasserstein distances between multidimensional distributions. It reduces each problem to a min-cost flow problem, solved with the min-cost flow implementation of the LEMON graph library and exposed to Python through the pylmcf module.
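To illustrate the reduction, here is a minimal pure-Python sketch. It builds a standard transportation network: a source feeds each peak of the first distribution, one edge per pair of peaks is weighted by their ground distance, and each peak of the second distribution drains into a sink. It then solves the network with successive shortest paths. This is not wnet's implementation: it swaps LEMON's solver for a simple Bellman-Ford-based algorithm, assumes integer intensities with equal total mass, and uses hypothetical function names (`ground_distance`, `wasserstein_min_cost_flow`).

```python
import math


def ground_distance(p, q, metric):
    """Distance between two points under the L1, L2 or L-infinity metric."""
    diffs = [abs(a - b) for a, b in zip(p, q)]
    if metric == "L1":
        return sum(diffs)
    if metric == "L2":
        return math.sqrt(sum(d * d for d in diffs))
    if metric == "Linf":
        return max(diffs)
    raise ValueError(f"unknown metric: {metric}")


def wasserstein_min_cost_flow(points1, w1, points2, w2, metric="L1"):
    """Optimal transport cost between two weighted point sets with equal
    total integer mass, via min-cost flow (successive shortest paths)."""
    n1, n2 = len(points1), len(points2)
    source, sink = n1 + n2, n1 + n2 + 1
    graph = [[] for _ in range(n1 + n2 + 2)]  # node -> list of edge ids
    to, cap, cost = [], [], []

    def add_edge(u, v, capacity, c):
        # Edge 2k is the forward edge, edge 2k+1 its residual reverse edge.
        for a, b, cp, cc in ((u, v, capacity, c), (v, u, 0, -c)):
            graph[a].append(len(to))
            to.append(b)
            cap.append(cp)
            cost.append(cc)

    for i, w in enumerate(w1):
        add_edge(source, i, w, 0)
    for j, w in enumerate(w2):
        add_edge(n1 + j, sink, w, 0)
    for i in range(n1):
        for j in range(n2):
            d = ground_distance(points1[i], points2[j], metric)
            add_edge(i, n1 + j, min(w1[i], w2[j]), d)

    total = 0.0
    while True:
        # Bellman-Ford in the residual graph (reverse edges have negative cost).
        dist = [math.inf] * len(graph)
        via = [-1] * len(graph)  # edge used to reach each node
        dist[source] = 0.0
        for _ in range(len(graph) - 1):
            changed = False
            for u in range(len(graph)):
                if dist[u] == math.inf:
                    continue
                for e in graph[u]:
                    if cap[e] > 0 and dist[u] + cost[e] < dist[to[e]] - 1e-12:
                        dist[to[e]] = dist[u] + cost[e]
                        via[to[e]] = e
                        changed = True
            if not changed:
                break
        if dist[sink] == math.inf:
            return total  # no augmenting path left: the flow is maximal
        # Bottleneck capacity along the path (to[e ^ 1] is the tail of edge e).
        push, v = math.inf, sink
        while v != source:
            push = min(push, cap[via[v]])
            v = to[via[v] ^ 1]
        # Augment: consume forward capacity, free up the reverse edge.
        v = sink
        while v != source:
            e = via[v]
            cap[e] -= push
            cap[e ^ 1] += push
            v = to[e ^ 1]
        total += push * dist[sink]


# The two distributions from the usage example, written as (x, y) points.
points1 = [(0, 0), (1, 0), (5, 0), (10, 3)]
w1 = [10, 5, 5, 5]
points2 = [(1, 0), (10, 0)]
w2 = [20, 5]
print(wasserstein_min_cost_flow(points1, w1, points2, w2, "L1"))  # 45.0
```

On the distributions from the usage example, written here as lists of points rather than wnet's one-row-per-dimension arrays, the sketch reproduces the value 45. Successive shortest paths is simple but scales poorly with total mass. It is used here only for readability.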
- Wasserstein and Truncated Wasserstein distance between multidimensional distributions (dimensions 1–20)
- Three distance metrics: L1, L2, L∞
- Derivatives with respect to peak intensities and spectrum mixture proportions
- Position gradients (∂cost/∂position) with warm-restart re-solving after peak position updates
- Support for distribution mixtures and efficient recalculation with changed mixture proportions
- Picklable `Distribution` objects
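The position gradients listed above can be understood through a standard envelope-theorem argument. With an optimal flow f held fixed, the cost is Σ f_ij d(x_i, y_j). Wherever that flow is unique and no matched pair coincides, ∂cost/∂x_i = Σ_j f_ij (x_i - y_j)/‖x_i - y_j‖ for the L2 metric, and the theoretical peaks y_j get the same terms with the opposite sign. The pure-Python sketch below assumes the flow is already known (here, the optimal L2 plan for the usage example, worked out by hand). The function name and the triple-based flow format are hypothetical, and wnet may handle ties and zero distances differently.

```python
import math


def l2_position_gradients(points1, points2, flow):
    """Gradients of the L2 transport cost with respect to every peak position,
    holding a fixed optimal flow (envelope-theorem argument).

    flow: list of (i, j, amount) triples meaning that `amount` units of mass
    move from points1[i] to points2[j].
    """
    dim = len(points1[0])
    grad1 = [[0.0] * dim for _ in points1]
    grad2 = [[0.0] * dim for _ in points2]
    for i, j, amount in flow:
        diff = [a - b for a, b in zip(points1[i], points2[j])]
        norm = math.sqrt(sum(d * d for d in diff))
        if norm == 0.0:
            continue  # |x - y| is not differentiable at 0; take the zero subgradient
        for k in range(dim):
            g = amount * diff[k] / norm
            grad1[i][k] += g  # d|x - y|/dx =  (x - y) / |x - y|
            grad2[j][k] -= g  # d|x - y|/dy = -(x - y) / |x - y|
    return grad1, grad2


# Usage-example distributions and their optimal L2 plan, worked out by hand.
points1 = [(0, 0), (1, 0), (5, 0), (10, 3)]
points2 = [(1, 0), (10, 0)]
flow = [(0, 0, 10), (1, 0, 5), (2, 0, 5), (3, 1, 5)]
grad1, grad2 = l2_position_gradients(points1, points2, flow)
print(grad1)  # [[-10.0, 0.0], [0.0, 0.0], [5.0, 0.0], [0.0, 5.0]]
print(grad2)  # [[5.0, 0.0], [0.0, -5.0]]
```

A gradient step then moves each empirical peak against its gradient. After the move the optimal flow can change, which is why wnet re-solves (with a warm restart) instead of reusing the old plan.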
You can install the Python package using pip:

```bash
pip install wnet
```

A basic example, computing the L1 Wasserstein distance between two distributions in two dimensions:

```python
import numpy as np
from wnet import WassersteinDistance, Distribution
from wnet.distances import DistanceMetric

# Positions have one row per dimension and one column per peak:
# S1 has 4 peaks and S2 has 2 peaks, both in 2 dimensions.
positions1 = np.array([[0, 1, 5, 10], [0, 0, 0, 3]])
intensities1 = np.array([10, 5, 5, 5])
positions2 = np.array([[1, 10], [0, 0]])
intensities2 = np.array([20, 5])

S1 = Distribution(positions1, intensities1)
S2 = Distribution(positions2, intensities2)

print(WassersteinDistance(S1, S2, DistanceMetric.L1))
# 45
```

With `TruncatedWassersteinDistance`, mass that cannot be matched within `max_distance` is discarded at a fixed cost rather than transported arbitrarily far:
```python
from wnet import TruncatedWassersteinDistance

print(TruncatedWassersteinDistance(S1, S2, DistanceMetric.L2, max_distance=3.0))
```

Derivatives are available through `WassersteinNetwork`. After solving, `signal_part_derivatives()` returns the marginal cost of increasing each theoretical peak's intensity by 1, which is useful for scoring how well each peak is explained:
```python
from wnet import WassersteinNetwork

W = WassersteinNetwork(S1, [S2], DistanceMetric.L2, max_distance=10.0)
W.build()
W.solve()
derivs = W.signal_part_derivatives()  # np.ndarray, one value per peak in S1
```

After an initial solve, peak positions can be updated and the problem re-solved cheaply via a warm restart. `update_positions_and_get_gradient()` returns ∂cost/∂position for all peaks, so you can feed the gradients directly into a gradient-based optimiser:
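```python
W = WassersteinNetwork(S1, [S2], DistanceMetric.L2, max_distance=10.0)
W.build()
W.solve()

# Assumption: new positions use the same layout as S1's (one row per
# dimension, one column per peak), as floats so they can be updated in place.
new_positions = positions1.astype(float)
for _ in range(100):
    grad_empirical, grad_theoretical = W.update_positions_and_get_gradient(new_positions)
    new_positions -= 0.01 * grad_empirical  # gradient-descent step, learning rate 0.01
```

This project is released under the MIT Licence.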
If you use this software, please cite:
Król J, Bochenek M, Jopa S, Kazimierczuk K, Gambin A, Startek MP (2026). WNetAlign: fast and accurate spectra alignment using truncated Wasserstein distance and network simplex. Briefings in Bioinformatics, 27(3), bbag247. https://doi.org/10.1093/bib/bbag247
```bibtex
@article{krol2026wnetalign,
  title   = {WNetAlign: fast and accurate spectra alignment using truncated Wasserstein distance and network simplex},
  author  = {Kr{\'o}l, Justyna and Bochenek, Maria and Jopa, Sylwia and Kazimierczuk, Krzysztof and Gambin, Anna and Startek, Micha{\l} Piotr},
  journal = {Briefings in Bioinformatics},
  volume  = {27},
  number  = {3},
  pages   = {bbag247},
  year    = {2026},
  doi     = {10.1093/bib/bbag247}
}
```