Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .lp import (
emd,
emd2,
emd2_lazy,
emd_1d,
emd2_1d,
wasserstein_1d,
Expand Down Expand Up @@ -82,6 +83,7 @@
__all__ = [
"emd",
"emd2",
"emd2_lazy",
"emd_1d",
"sinkhorn",
"sinkhorn2",
Expand Down
25 changes: 16 additions & 9 deletions ot/gromov/_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def GW_distance_estimation(

for i in range(nb_samples_p):
if nx.issparse(T):
T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,))
T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,))
T_indexi = nx.reshape(nx.todense(T[[index_i[i]], :]), (-1,))
T_indexj = nx.reshape(nx.todense(T[[index_j[i]], :]), (-1,))
else:
T_indexi = T[index_i[i], :]
T_indexj = T[index_j[i], :]
Expand Down Expand Up @@ -243,16 +243,18 @@ def pointwise_gromov_wasserstein(
index = np.zeros(2, dtype=int)

# Initialize with default marginal
index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q))
index[0] = int(generator.choice(len_p, size=1, p=nx.to_numpy(p)).item())
index[1] = int(generator.choice(len_q, size=1, p=nx.to_numpy(q)).item())
T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False))

best_gw_dist_estimated = np.inf
for cpt in range(max_iter):
index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,))
index[1] = generator.choice(
len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0))
index[0] = int(generator.choice(len_p, size=1, p=nx.to_numpy(p)).item())
T_index0 = nx.reshape(nx.todense(T[[index[0]], :]), (-1,))
index[1] = int(
generator.choice(
len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0))
).item()
)

if alpha == 1:
Expand Down Expand Up @@ -404,10 +406,15 @@ def sampled_gromov_wasserstein(
)
Lik = 0
for i, index0_i in enumerate(index0):
T_row = (
nx.reshape(nx.todense(T[[index0_i], :]), (-1,))
if nx.issparse(T)
else T[index0_i, :]
)
index1 = generator.choice(
len_q,
size=nb_samples_grad_q,
p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])),
p=nx.to_numpy(T_row / nx.sum(T_row)),
replace=False,
)
# If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly.
Expand Down
16 changes: 16 additions & 0 deletions ot/lp/EMD.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,21 @@ int EMD_wrap_sparse(
uint64_t maxIter // Maximum iterations for solver
);

// Solve the EMD (optimal transport) problem between two weighted point sets
// without a precomputed cost matrix: the caller supplies point coordinates and
// a metric code, and arc costs are evaluated by the solver on demand.
//
// All arrays are caller-allocated. coords_a / coords_b are read as row-major
// (point index * dim + coordinate index). G is written as a row-major
// n1 x n2 matrix.
//
// Returns the status code produced by the network simplex solver's run().
// NOTE(review): the mapping of return values to status names is defined
// outside this declaration - confirm against the solver's enum before
// branching on specific integers.
int EMD_wrap_lazy(
int n1, // Number of source points
int n2, // Number of target points
double *X, // Source weights (n1)
double *Y, // Target weights (n2)
double *coords_a, // Source coordinates (n1 x dim)
double *coords_b, // Target coordinates (n2 x dim)
int dim, // Dimension of coordinates
int metric, // Distance metric: 0=sqeuclidean, 1=euclidean, 2=cityblock
double *G, // Output: transport plan (n1 x n2)
double *alpha, // Output: dual variables for sources (n1)
double *beta, // Output: dual variables for targets (n2)
double *cost, // Output: total transportation cost
uint64_t maxIter // Maximum iterations for solver
);


#endif
106 changes: 105 additions & 1 deletion ot/lp/EMD_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,4 +370,108 @@ int EMD_wrap_sparse(
}
}
return ret;
}
}

// Zero the three caller-owned output buffers of EMD_wrap_lazy.
// The plan size is computed in 64 bits so that n1 * n2 cannot overflow int.
static void EMD_wrap_lazy_zero_outputs(int n1, int n2, double *G, double *alpha, double *beta) {
    const uint64_t plan_size = static_cast<uint64_t>(n1) * static_cast<uint64_t>(n2);
    for (uint64_t k = 0; k < plan_size; k++) G[k] = 0.0;
    for (int i = 0; i < n1; i++) alpha[i] = 0.0;
    for (int j = 0; j < n2; j++) beta[j] = 0.0;
}

// Solve the EMD problem between two weighted point clouds with lazily
// evaluated costs (no n1 x n2 cost matrix is built by this wrapper).
//
// Parameters:
//   n1, n2      number of source / target points
//   X, Y        source / target weights (length n1 / n2); only entries > 0
//               take part in the problem
//   coords_a/b  row-major coordinates, (n1 x dim) and (n2 x dim)
//   dim         coordinate dimension
//   metric      metric code forwarded unchanged to the solver's setLazyCost
//   G           output, row-major n1 x n2 transport plan
//   alpha, beta output dual variables (length n1 / n2)
//   cost        output, sum of flow * cost over arcs with positive flow
//   maxIter     iteration cap forwarded to the solver
//
// Returns the solver status from run(). Outputs are written when the status is
// OPTIMAL or MAX_ITER_REACHED, and zeroed when either side has no positive
// weight; on any other status G, alpha, beta and cost are left untouched.
int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b,
                  int dim, int metric, double *G, double *alpha, double *beta,
                  double *cost, uint64_t maxIter) {
    using namespace lemon;
    typedef FullBipartiteDigraph Digraph;
    DIGRAPH_TYPEDEFS(Digraph);

    // All flat-array offsets below are computed in 64 bits: products such as
    // n1 * n2 or i * dim overflow int for large problems.
    const uint64_t udim = static_cast<uint64_t>(dim);

    // Keep only source nodes with strictly positive weight.
    // NOTE(review): negative weights are silently dropped here rather than
    // reported as an error - confirm this is the intended contract.
    std::vector<int> idx_a;
    std::vector<double> weights_a_filtered;
    std::vector<double> coords_a_filtered;

    // Reserve space to avoid reallocations
    idx_a.reserve(n1);
    weights_a_filtered.reserve(n1);
    coords_a_filtered.reserve(static_cast<uint64_t>(n1) * udim);

    for (int i = 0; i < n1; i++) {
        if (X[i] > 0) {
            idx_a.push_back(i);
            weights_a_filtered.push_back(X[i]);
            const double *row = coords_a + static_cast<uint64_t>(i) * udim;
            for (int d = 0; d < dim; d++) {
                coords_a_filtered.push_back(row[d]);
            }
        }
    }
    const int n = static_cast<int>(idx_a.size());

    // Keep only target nodes with strictly positive weight.
    std::vector<int> idx_b;
    std::vector<double> weights_b_filtered;
    std::vector<double> coords_b_filtered;

    // Reserve space to avoid reallocations
    idx_b.reserve(n2);
    weights_b_filtered.reserve(n2);
    coords_b_filtered.reserve(static_cast<uint64_t>(n2) * udim);

    for (int j = 0; j < n2; j++) {
        if (Y[j] > 0) {
            idx_b.push_back(j);
            weights_b_filtered.push_back(-Y[j]); // Demand is negative supply
            const double *row = coords_b + static_cast<uint64_t>(j) * udim;
            for (int d = 0; d < dim; d++) {
                coords_b_filtered.push_back(row[d]);
            }
        }
    }
    const int m = static_cast<int>(idx_b.size());

    if (n == 0 || m == 0) {
        // Nothing to transport: hand back a well-defined (all-zero) result
        // instead of leaving the output buffers unwritten.
        // NOTE(review): the literal 0 is returned as before; confirm which
        // solver status this value corresponds to.
        EMD_wrap_lazy_zero_outputs(n1, n2, G, alpha, beta);
        *cost = 0.0;
        return 0;
    }

    // Create full bipartite graph
    Digraph di(n, m);

    NetworkSimplexSimple<Digraph, double, double, node_id_type> net(
        di, true, static_cast<int>(n + m),
        static_cast<uint64_t>(n) * static_cast<uint64_t>(m), maxIter
    );

    // Set supplies
    net.supplyMap(weights_a_filtered.data(), n, weights_b_filtered.data(), m);

    // Enable lazy cost computation - costs will be computed on-the-fly
    net.setLazyCost(coords_a_filtered.data(), coords_b_filtered.data(), dim, metric, n, m);

    // Run solver
    const int ret = net.run();

    if (ret == static_cast<int>(net.OPTIMAL) || ret == static_cast<int>(net.MAX_ITER_REACHED)) {
        EMD_wrap_lazy_zero_outputs(n1, n2, G, alpha, beta);

        // Dual variables depend only on the node, so write them once per node
        // rather than once per arc. Target nodes are numbered after the n
        // source nodes in the bipartite graph.
        for (int i = 0; i < n; i++) {
            alpha[idx_a[i]] = -net.potential(i);
        }
        for (int j = 0; j < m; j++) {
            beta[idx_b[j]] = net.potential(j + n);
        }

        // Extract the flow on every arc and accumulate the total cost.
        double total_cost = 0.0;
        Arc a;
        di.first(a);
        for (; a != INVALID; di.next(a)) {
            const int i = di.source(a);
            const int j = di.target(a) - n;

            const double flow = net.flow(a);
            // Map filtered indices back to positions in the original arrays.
            const uint64_t offset = static_cast<uint64_t>(idx_a[i]) * static_cast<uint64_t>(n2)
                                    + static_cast<uint64_t>(idx_b[j]);
            G[offset] = flow;

            // Only arcs that carry flow need their cost evaluated.
            if (flow > 0) {
                total_cost += flow * net.computeLazyCost(i, j);
            }
        }
        *cost = total_cost;
    }

    return ret;
}
3 changes: 2 additions & 1 deletion ot/lp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# License: MIT License

from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize
from ._network_simplex import emd, emd2
from ._network_simplex import emd, emd2, emd2_lazy
from ._barycenter_solvers import (
barycenter,
free_support_barycenter,
Expand All @@ -35,6 +35,7 @@
__all__ = [
"emd",
"emd2",
"emd2_lazy",
"barycenter",
"free_support_barycenter",
"cvx",
Expand Down
Loading
Loading