diff --git a/ot/__init__.py b/ot/__init__.py index 8a389bb98..26f428aa1 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -41,6 +41,7 @@ from .lp import ( emd, emd2, + emd2_lazy, emd_1d, emd2_1d, wasserstein_1d, @@ -82,6 +83,7 @@ __all__ = [ "emd", "emd2", + "emd2_lazy", "emd_1d", "sinkhorn", "sinkhorn2", diff --git a/ot/gromov/_estimators.py b/ot/gromov/_estimators.py index 14871bfe3..18359afe5 100644 --- a/ot/gromov/_estimators.py +++ b/ot/gromov/_estimators.py @@ -122,8 +122,8 @@ def GW_distance_estimation( for i in range(nb_samples_p): if nx.issparse(T): - T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,)) - T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,)) + T_indexi = nx.reshape(nx.todense(T[[index_i[i]], :]), (-1,)) + T_indexj = nx.reshape(nx.todense(T[[index_j[i]], :]), (-1,)) else: T_indexi = T[index_i[i], :] T_indexj = T[index_j[i], :] @@ -243,16 +243,18 @@ def pointwise_gromov_wasserstein( index = np.zeros(2, dtype=int) # Initialize with default marginal - index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p)) - index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q)) + index[0] = int(generator.choice(len_p, size=1, p=nx.to_numpy(p)).item()) + index[1] = int(generator.choice(len_q, size=1, p=nx.to_numpy(q)).item()) T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)) best_gw_dist_estimated = np.inf for cpt in range(max_iter): - index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p)) - T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,)) - index[1] = generator.choice( - len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0)) + index[0] = int(generator.choice(len_p, size=1, p=nx.to_numpy(p)).item()) + T_index0 = nx.reshape(nx.todense(T[[index[0]], :]), (-1,)) + index[1] = int( + generator.choice( + len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0)) + ).item() ) if alpha == 1: @@ -404,10 +406,15 @@ def sampled_gromov_wasserstein( ) Lik = 0 for i, index0_i in enumerate(index0): + T_row = ( + nx.reshape(nx.todense(T[[index0_i], :]), (-1,)) + if nx.issparse(T) + else T[index0_i, :] + ) index1 = generator.choice( len_q, size=nb_samples_grad_q, - p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])), + p=nx.to_numpy(T_row / nx.sum(T_row)), replace=False, ) # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly. 
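The changes below add a lazy EMD path, exported above as ot.emd2_lazy and wired into ot.solve_sample(..., lazy=True) further down. A minimal usage sketch, assuming only the signatures introduced in this diff (the sizes and random data are illustrative, not part of the change):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    X_a, X_b = rng.rand(500, 3), rng.rand(600, 3)  # source / target samples: (ns, dim), (nt, dim)
    a, b = ot.utils.unif(500), ot.utils.unif(600)  # uniform histograms

    # Direct call: pairwise costs are evaluated on-the-fly inside the C++ solver,
    # so no (ns, nt) cost matrix is pre-computed on the Python side.
    loss, log = ot.emd2_lazy(a, b, X_a, X_b, metric="sqeuclidean", log=True, return_matrix=True)

    # Same problem through the generic sample solver (balanced, unregularized, lazy=True).
    res = ot.solve_sample(X_a, X_b, a, b, lazy=True, metric="sqeuclidean")
    np.testing.assert_allclose(loss, res.value)  # both routes agree, as the new tests check
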
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index e3564a2d2..6f408ffeb 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -51,5 +51,21 @@ int EMD_wrap_sparse( uint64_t maxIter // Maximum iterations for solver ); +int EMD_wrap_lazy( + int n1, // Number of source points + int n2, // Number of target points + double *X, // Source weights (n1) + double *Y, // Target weights (n2) + double *coords_a, // Source coordinates (n1 x dim) + double *coords_b, // Target coordinates (n2 x dim) + int dim, // Dimension of coordinates + int metric, // Distance metric: 0=sqeuclidean, 1=euclidean, 2=cityblock + double *G, // Output: transport plan (n1 x n2) + double *alpha, // Output: dual variables for sources (n1) + double *beta, // Output: dual variables for targets (n2) + double *cost, // Output: total transportation cost + uint64_t maxIter // Maximum iterations for solver +); + #endif diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index bd3672535..6aa27897a 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -370,4 +370,108 @@ int EMD_wrap_sparse( } } return ret; -} \ No newline at end of file +} + +int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, + int dim, int metric, double *G, double *alpha, double *beta, + double *cost, uint64_t maxIter) { + using namespace lemon; + typedef FullBipartiteDigraph Digraph; + DIGRAPH_TYPEDEFS(Digraph); + + // Filter source nodes with non-zero weights + std::vector<int> idx_a; + std::vector<double> weights_a_filtered; + std::vector<double> coords_a_filtered; + + // Reserve space to avoid reallocations + idx_a.reserve(n1); + weights_a_filtered.reserve(n1); + coords_a_filtered.reserve(n1 * dim); + + for (int i = 0; i < n1; i++) { + if (X[i] > 0) { + idx_a.push_back(i); + weights_a_filtered.push_back(X[i]); + for (int d = 0; d < dim; d++) { + coords_a_filtered.push_back(coords_a[i * dim + d]); + } + } + } + int n = idx_a.size(); + + // Filter target nodes with non-zero weights + std::vector<int> idx_b; + std::vector<double> weights_b_filtered; + std::vector<double> coords_b_filtered; + + // Reserve space to avoid reallocations + idx_b.reserve(n2); + weights_b_filtered.reserve(n2); + coords_b_filtered.reserve(n2 * dim); + + for (int j = 0; j < n2; j++) { + if (Y[j] > 0) { + idx_b.push_back(j); + weights_b_filtered.push_back(-Y[j]); // Demand is negative supply + for (int d = 0; d < dim; d++) { + coords_b_filtered.push_back(coords_b[j * dim + d]); + } + } + } + int m = idx_b.size(); + + if (n == 0 || m == 0) { + *cost = 0.0; + return 0; + } + + // Create full bipartite graph + Digraph di(n, m); + + NetworkSimplexSimple<Digraph, double, double, node_id_type> net( + di, true, (int)(n + m), (uint64_t)(n) * (uint64_t)(m), maxIter + ); + + // Set supplies + net.supplyMap(&weights_a_filtered[0], n, &weights_b_filtered[0], m); + + // Enable lazy cost computation - costs will be computed on-the-fly + net.setLazyCost(&coords_a_filtered[0], &coords_b_filtered[0], dim, metric, n, m); + + // Run solver + int ret = net.run(); + + if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED) { + *cost = 0; + + // Initialize output arrays + for (int i = 0; i < n1 * n2; i++) G[i] = 0.0; + for (int i = 0; i < n1; i++) alpha[i] = 0.0; + for (int i = 0; i < n2; i++) beta[i] = 0.0; + + // Extract solution + Arc a; + di.first(a); + for (; a != INVALID; di.next(a)) { + int i = di.source(a); + int j = di.target(a) - n; + + int orig_i = idx_a[i]; + int orig_j = idx_b[j]; + + double flow = net.flow(a); + G[orig_i * n2 + orig_j] = flow; + + alpha[orig_i] = -net.potential(i); + beta[orig_j] = net.potential(j + n); + + if
(flow > 0) { + double c = net.computeLazyCost(i, j); + *cost += flow * c; + } + } + } + + return ret; +} diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index f8924a322..8e88d63c8 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -9,7 +9,7 @@ # License: MIT License from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize -from ._network_simplex import emd, emd2 +from ._network_simplex import emd, emd2, emd2_lazy from ._barycenter_solvers import ( barycenter, free_support_barycenter, @@ -35,6 +35,7 @@ __all__ = [ "emd", "emd2", + "emd2_lazy", "barycenter", "free_support_barycenter", "cvx", diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py index d4dfa1ec3..7cce0113c 100644 --- a/ot/lp/_network_simplex.py +++ b/ot/lp/_network_simplex.py @@ -13,7 +13,7 @@ from ..utils import list_to_array, check_number_threads from ..backend import get_backend -from .emd_wrap import emd_c, emd_c_sparse, check_result +from .emd_wrap import emd_c, emd_c_sparse, emd_c_lazy, check_result def center_ot_dual(alpha0, beta0, a=None, b=None): @@ -320,20 +320,20 @@ def emd( if edge_costs.dtype != np.float64: edge_costs = edge_costs.astype(np.float64) - if len(a) != 0: + if a is not None and len(a) != 0: type_as = a - elif len(b) != 0: + elif b is not None and len(b) != 0: type_as = b else: - type_as = a + type_as = a if a is not None else b # Set n1, n2 if not already set (dense case) if n1 is None: n1, n2 = M.shape - if len(a) == 0: + if a is None or len(a) == 0: a = nx.ones((n1,), type_as=type_as) / n1 - if len(b) == 0: + if b is None or len(b) == 0: b = nx.ones((n2,), type_as=type_as) / n2 if is_sparse: @@ -471,7 +471,7 @@ def emd2( .. note:: This function will cast the computed transport plan and transportation loss to the data type of the provided input with the - following priority: :math:`\mathbf{a}`, then :math:`\mathbf{b}`, + following priority : :math:`\mathbf{a}`, then :math:`\mathbf{b}`, then :math:`\mathbf{M}` if marginals are not provided. Casting to an integer tensor might result in a loss of precision. If this behaviour is unwanted, please make sure to provide a @@ -591,21 +591,21 @@ def emd2( if edge_costs.dtype != np.float64: edge_costs = edge_costs.astype(np.float64) - if len(a) != 0: + if a is not None and len(a) != 0: type_as = a - elif len(b) != 0: + elif b is not None and len(b) != 0: type_as = b else: - type_as = a + type_as = a if a is not None else b # Set n1, n2 if not already set (dense case) if n1 is None: n1, n2 = M.shape # if empty array given then use uniform distributions - if len(a) == 0: + if a is None or len(a) == 0: a = nx.ones((n1,), type_as=type_as) / n1 - if len(b) == 0: + if b is None or len(b) == 0: b = nx.ones((n2,), type_as=type_as) / n2 a0, b0 = a, b @@ -775,3 +775,188 @@ def f(b): res = list(map(f, [b[:, i].copy() for i in range(nb)])) return res + + +def emd2_lazy( + a, + b, + X_a, + X_b, + metric="sqeuclidean", + numItermax=100000, + log=False, + return_matrix=False, + center_dual=True, + check_marginals=True, +): + r"""Solves the Earth Movers distance problem with lazy cost computation and returns the loss + + .. math:: + \min_\gamma \quad \langle \gamma, \mathbf{M}(\mathbf{X}_a, \mathbf{X}_b) \rangle_F + + s.t. \ \gamma \mathbf{1} = \mathbf{a} + + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 + + where : + + - :math:`\mathbf{M}(\mathbf{X}_a, \mathbf{X}_b)` is computed on-the-fly from coordinates + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + + .. 
note:: This function computes distances on-the-fly during the network simplex algorithm, + avoiding the O(ns*nt) memory cost of pre-computing and storing the full cost matrix. + Note that only the cost matrix is avoided: the returned transport plan is a dense + (ns, nt) array, so the overall memory footprint is not reduced to O(ns+nt). + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + Parameters + ---------- + a : (ns,) array-like, float64 + Source histogram (uniform weight if empty list) + b : (nt,) array-like, float64 + Target histogram (uniform weight if empty list) + X_a : (ns, dim) array-like, float64 + Source sample coordinates + X_b : (nt, dim) array-like, float64 + Target sample coordinates + metric : str, optional (default='sqeuclidean') + Distance metric for cost computation. Options: + + - 'sqeuclidean': Squared Euclidean distance + - 'euclidean': Euclidean distance + - 'cityblock': Manhattan/L1 distance + + numItermax : int, optional (default=100000) + Maximum number of iterations before stopping if not converged + log: boolean, optional (default=False) + If True, returns a dictionary containing the cost, dual variables, + and optionally the transport plan (dense (ns, nt) array) + return_matrix: boolean, optional (default=False) + If True, returns the optimal transportation matrix in the log (dense (ns, nt) array) + center_dual: boolean, optional (default=True) + If True, centers the dual potential using :py:func:`ot.lp.center_ot_dual` + check_marginals: bool, optional (default=True) + If True, checks that the marginal masses are equal + + Returns + ------- + W: float + Optimal transportation loss + log: dict + If input log is True, a dictionary containing: + + - cost: the optimal transportation cost + - u, v: dual variables + - warning: solver status message + - result_code: solver return code + - G: (optional) dense transport plan if return_matrix=True + + See Also + -------- + ot.emd2 : EMD with pre-computed cost matrix + ot.lp.emd_c_lazy : Low-level Cython wrapper around the C++ lazy solver + """ + + a, b, X_a, X_b = list_to_array(a, b, X_a, X_b) + nx = get_backend(a, b, X_a, X_b) + + n1, n2 = X_a.shape[0], X_b.shape[0] + + # Validate dimensions match + if X_a.shape[1] != X_b.shape[1]: + raise ValueError( + f"X_a and X_b must have the same number of dimensions, " + f"got {X_a.shape[1]} and {X_b.shape[1]}" + ) + + if a is not None and len(a) != 0: + type_as = a + elif b is not None and len(b) != 0: + type_as = b + else: + type_as = X_a + + # if empty array given then use uniform distributions + if a is None or len(a) == 0: + a = nx.ones((n1,), type_as=type_as) / n1 + if b is None or len(b) == 0: + b = nx.ones((n2,), type_as=type_as) / n2 + + a0, b0 = a, b + + # Convert to numpy for C++ backend + X_a_np = nx.to_numpy(X_a) + X_b_np = nx.to_numpy(X_b) + a_np = nx.to_numpy(a) + b_np = nx.to_numpy(b) + + X_a_np = np.asarray(X_a_np, dtype=np.float64, order="C") + X_b_np = np.asarray(X_b_np, dtype=np.float64, order="C") + a_np = np.asarray(a_np, dtype=np.float64) + b_np = np.asarray(b_np, dtype=np.float64) + + assert ( + a_np.shape[0] == n1 and b_np.shape[0] == n2 + ), "Dimension mismatch, check dimensions of X_a/X_b with a and b" + + # ensure that same mass + if check_marginals: + np.testing.assert_almost_equal( + a_np.sum(0), + b_np.sum(0, keepdims=True), + err_msg="a and b vector must have the same sum", + decimal=6, + ) + b_np = b_np * a_np.sum(0) / b_np.sum(0, keepdims=True) + + # Solve with lazy cost computation + G, cost, u, v, result_code = emd_c_lazy( + a_np, b_np, X_a_np, X_b_np, metric, 
numItermax + ) + + # Center dual potentials + if center_dual: + u, v = center_ot_dual(u, v, a_np, b_np) + + # Convert sparse plan to backend format + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2, + ) + + G_backend = nx.from_numpy(G, type_as=type_as) + + # Set gradients wrt marginals + cost_backend = nx.set_gradients( + nx.from_numpy(cost, type_as=type_as), + (a0, b0), + ( + nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), + ), + ) + + check_result(result_code) + + # Return results + if log or return_matrix: + log_dict = { + "cost": cost_backend, + "u": nx.from_numpy(u, type_as=type_as), + "v": nx.from_numpy(v, type_as=type_as), + "warning": check_result(result_code), + "result_code": result_code, + } + if return_matrix: + log_dict["G"] = G_backend + return cost_backend, log_dict + else: + return cost_backend diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 4ce315f5f..2b2a5a1ef 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -23,6 +23,7 @@ cdef extern from "EMD.h": int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil + int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, int dim, int metric, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -285,4 +286,39 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a, flow_targets = flow_targets[:n_flows_out] flow_values = flow_values[:n_flows_out] - return flow_sources, flow_targets, flow_values, cost, alpha, beta, result_code \ No newline at end of file + return flow_sources, flow_targets, flow_values, cost, alpha, beta, result_code + + +@cython.boundscheck(False) +@cython.wraparound(False) +def emd_c_lazy(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] coords_a, np.ndarray[double, ndim=2, mode="c"] coords_b, str metric='sqeuclidean', uint64_t max_iter=100000): + """Solves the Earth Movers distance problem with lazy cost computation from coordinates.""" + cdef int n1 = coords_a.shape[0] + cdef int n2 = coords_b.shape[0] + cdef int dim = coords_a.shape[1] + cdef int result_code = 0 + cdef double cost = 0 + cdef int metric_code + + # Validate dimension consistency + if coords_b.shape[1] != dim: + raise ValueError(f"Coordinate dimension mismatch: coords_a has {dim} dimensions but coords_b has {coords_b.shape[1]}") + + if metric == 'sqeuclidean': + metric_code = 0 + elif metric == 'euclidean': + metric_code = 1 + elif metric == 'cityblock': + metric_code = 2 + else: + raise ValueError(f"Unknown metric: {metric}") + cdef np.ndarray[double, ndim=1, 
mode="c"] alpha = np.zeros(n1) + cdef np.ndarray[double, ndim=1, mode="c"] beta = np.zeros(n2) + cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros([n1, n2]) + if not len(a): + a = np.ones((n1,)) / n1 + if not len(b): + b = np.ones((n2,)) / n2 + with nogil: + result_code = EMD_wrap_lazy(n1, n2, a.data, b.data, coords_a.data, coords_b.data, dim, metric_code, G.data, alpha.data, beta.data, &cost, max_iter) + return G, cost, alpha, beta, result_code diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 9612a8a24..1566e74e0 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -27,18 +27,10 @@ #pragma once #undef DEBUG_LVL -#define DEBUG_LVL 0 - -#if DEBUG_LVL>0 -#include -#endif - #undef EPSILON #undef _EPSILON -#undef MAX_DEBUG_ITER #define EPSILON 2.2204460492503131e-15 #define _EPSILON 1e-8 -#define MAX_DEBUG_ITER 100000 /// \ingroup min_cost_flow_algs @@ -238,7 +230,8 @@ namespace lemon { _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs), MAX(std::numeric_limits::max()), INF(std::numeric_limits::has_infinity ? - std::numeric_limits::infinity() : MAX) + std::numeric_limits::infinity() : MAX), + _lazy_cost(false), _coords_a(nullptr), _coords_b(nullptr), _dim(0), _metric(0), _n1(0), _n2(0) { // Reset data structures reset(); @@ -320,6 +313,8 @@ namespace lemon { // Data related to the underlying digraph const GR &_graph; int _node_num; + int _n1; // Number of source nodes (for lazy cost computation) + int _n2; // Number of target nodes (for lazy cost computation) ArcsType _arc_num; ArcsType _all_arc_num; ArcsType _search_arc_num; @@ -342,6 +337,12 @@ namespace lemon { //SparseValueVector _flow; CostVector _pi; + // Lazy cost computation support + bool _lazy_cost; + const double* _coords_a; + const double* _coords_b; + int _dim; + int _metric; // 0: sqeuclidean, 1: euclidean, 2: cityblock private: // Data for storing the spanning tree structure @@ -470,6 +471,41 @@ namespace lemon { _block_size = std::max(ArcsType(BLOCK_SIZE_FACTOR * std::sqrt(double(_search_arc_num))), MIN_BLOCK_SIZE); } + // Get cost for an arc (either from pre-computed array or compute lazily) + inline Cost getCost(ArcsType e) const { + if (!_ns._lazy_cost) { + return _cost[e]; + } else { + // For lazy mode, compute cost from coordinates inline + // _source and _target use reversed node numbering + int i = _ns._node_num - _source[e] - 1; + int j = _ns._n2 - _target[e] - 1; + + const double* __restrict__ xa = _ns._coords_a + i * _ns._dim; + const double* __restrict__ xb = _ns._coords_b + j * _ns._dim; + Cost cost = 0; + + if (_ns._metric == 0) { // sqeuclidean + for (int d = 0; d < _ns._dim; ++d) { + Cost diff = xa[d] - xb[d]; + cost += diff * diff; + } + return cost; + } else if (_ns._metric == 1) { // euclidean + for (int d = 0; d < _ns._dim; ++d) { + Cost diff = xa[d] - xb[d]; + cost += diff * diff; + } + return std::sqrt(cost); + } else { // cityblock + for (int d = 0; d < _ns._dim; ++d) { + cost += std::abs(xa[d] - xb[d]); + } + return cost; + } + } + } + // Find next entering arc bool findEnteringArc() { Cost c, min = 0; @@ -477,33 +513,33 @@ namespace lemon { ArcsType cnt = _block_size; double a; for (e = _next_arc; e != _search_arc_num; ++e) { - c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); + c = _state[e] * (getCost(e) + _pi[_source[e]] - _pi[_target[e]]); if (c < min) { min = c; _in_arc = e; } if (--cnt == 0) { a=fabs(_pi[_source[_in_arc]])>fabs(_pi[_target[_in_arc]]) ? 
fabs(_pi[_source[_in_arc]]):fabs(_pi[_target[_in_arc]]); - a=a>fabs(_cost[_in_arc])?a:fabs(_cost[_in_arc]); + a=a>fabs(getCost(_in_arc))?a:fabs(getCost(_in_arc)); if (min < -EPSILON*a) goto search_end; cnt = _block_size; } } for (e = 0; e != _next_arc; ++e) { - c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); + c = _state[e] * (getCost(e) + _pi[_source[e]] - _pi[_target[e]]); if (c < min) { min = c; _in_arc = e; } if (--cnt == 0) { a=fabs(_pi[_source[_in_arc]])>fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]):fabs(_pi[_target[_in_arc]]); - a=a>fabs(_cost[_in_arc])?a:fabs(_cost[_in_arc]); + a=a>fabs(getCost(_in_arc))?a:fabs(getCost(_in_arc)); if (min < -EPSILON*a) goto search_end; cnt = _block_size; } } a=fabs(_pi[_source[_in_arc]])>fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]):fabs(_pi[_target[_in_arc]]); - a=a>fabs(_cost[_in_arc])?a:fabs(_cost[_in_arc]); + a=a>fabs(getCost(_in_arc))?a:fabs(getCost(_in_arc)); if (min >= -EPSILON*a) return false; search_end: @@ -565,6 +601,90 @@ namespace lemon { return *this; } + /// \brief Enable lazy cost computation from coordinates. + /// + /// This function enables lazy cost computation where distances are + /// computed on-the-fly from point coordinates instead of using a + /// pre-computed cost matrix. + /// + /// \param coords_a Pointer to source coordinates (n1 x dim array) + /// \param coords_b Pointer to target coordinates (n2 x dim array) + /// \param dim Dimension of the coordinates + /// \param metric Distance metric: 0=sqeuclidean, 1=euclidean, 2=cityblock + /// + /// \return (*this) + NetworkSimplexSimple& setLazyCost(const double* coords_a, const double* coords_b, + int dim, int metric, int n1, int n2) { + _lazy_cost = true; + _coords_a = coords_a; + _coords_b = coords_b; + _dim = dim; + _metric = metric; + _n1 = n1; + _n2 = n2; + return *this; + } + + /// \brief Compute cost lazily from coordinates. + /// + /// Computes the distance between source node i and target node j + /// based on the specified metric. + /// + /// \param i Source node index + /// \param j Target node index (adjusted by subtracting n1) + /// + /// \return Cost (distance) between the two points + inline Cost computeLazyCost(int i, int j_adjusted) const { + const double* __restrict__ xa = _coords_a + i * _dim; + const double* __restrict__ xb = _coords_b + j_adjusted * _dim; + Cost cost = 0; + + if (_metric == 0) { // sqeuclidean + for (int d = 0; d < _dim; ++d) { + Cost diff = xa[d] - xb[d]; + cost += diff * diff; + } + return cost; + } else if (_metric == 1) { // euclidean + for (int d = 0; d < _dim; ++d) { + Cost diff = xa[d] - xb[d]; + cost += diff * diff; + } + return std::sqrt(cost); + } else { // cityblock (L1) + for (int d = 0; d < _dim; ++d) { + cost += std::abs(xa[d] - xb[d]); + } + return cost; + } + } + + + /// \brief Get cost for an arc (either from array or compute lazily). + /// + /// This is the main cost accessor that works from anywhere in the class. + /// In lazy mode, computes cost on-the-fly from coordinates. + /// In normal mode, returns pre-computed cost from array. 
+ /// + /// \param arc_id The arc ID + /// \return Cost of the arc + inline Cost getCostForArc(ArcsType arc_id) const { + if (!_lazy_cost) { + return _cost[arc_id]; + } else { + // For artificial arcs (>= _arc_num), return 0 + // These are not real transport arcs + if (arc_id >= _arc_num) { + return 0; + } + // Compute lazily from coordinates + // _source and _target use reversed node numbering: _node_id(n) = _node_num - n - 1 + // Recover original indices: i = _node_num - _source[arc_id] - 1, j = _n2 - _target[arc_id] - 1 + int i = _node_num - _source[arc_id] - 1; + int j = _n2 - _target[arc_id] - 1; + return computeLazyCost(i, j); + } + } /// \brief Set the supply values of the nodes. /// @@ -689,14 +809,7 @@ namespace lemon { /// \see ProblemType, PivotRule /// \see resetParams(), reset() ProblemType run() { -#if DEBUG_LVL>0 - std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED << "\n" ; -#endif - if (!init()) return INFEASIBLE; -#if DEBUG_LVL>0 - std::cout << "Init done, starting iterations\n"; -#endif return start(); } @@ -879,8 +992,19 @@ namespace lemon { c += Number(it->second) * Number(_cost[it->first]); return c;*/ - for (ArcsType i=0; i<_flow.size(); i++) - c += _flow[i] * Number(_cost[i]); + if (!_lazy_cost) { + for (ArcsType i=0; i<_flow.size(); i++) + c += _flow[i] * Number(_cost[i]); + } else { + // Compute costs lazily + for (ArcsType i=0; i<_flow.size(); i++) { + if (_flow[i] != 0) { + int src = _node_num - _source[i] - 1; + int tgt = _n2 - _target[i] - 1; + c += _flow[i] * Number(computeLazyCost(src, tgt)); + } + } + } return c; } @@ -965,7 +1089,8 @@ namespace lemon { } else { ART_COST = 0; for (ArcsType i = 0; i != _arc_num; ++i) { - if (_cost[i] > ART_COST) ART_COST = _cost[i]; + Cost c = getCostForArc(i); + if (c > ART_COST) ART_COST = c; } ART_COST = (ART_COST + 1) * _node_num; } @@ -1305,8 +1430,8 @@ namespace lemon { // Update potentials void updatePotential() { Cost sigma = _forward[u_in] ? - _pi[v_in] - _pi[u_in] - _cost[_pred[u_in]] : - _pi[v_in] - _pi[u_in] + _cost[_pred[u_in]]; + _pi[v_in] - _pi[u_in] - getCostForArc(_pred[u_in]) : + _pi[v_in] - _pi[u_in] + getCostForArc(_pred[u_in]); // Update potentials in the subtree, which has been moved int end = _thread[_last_succ[u_in]]; for (int u = u_in; u != end; u = _thread[u]) { @@ -1365,7 +1490,7 @@ namespace lemon { Arc min_arc = INVALID; Arc a; _graph.firstIn(a, v); for (; a != INVALID; _graph.nextIn(a)) { - c = _cost[getArcID(a)]; + c = getCostForArc(getArcID(a)); if (c < min_cost) { min_cost = c; min_arc = a; @@ -1384,7 +1509,7 @@ namespace lemon { Arc min_arc = INVALID; Arc a; _graph.firstOut(a, u); for (; a != INVALID; _graph.nextOut(a)) { - c = _cost[getArcID(a)]; + c = getCostForArc(getArcID(a)); if (c < min_cost) { min_cost = c; min_arc = a; @@ -1400,7 +1525,7 @@ namespace lemon { for (ArcsType i = 0; i != arc_vector.size(); ++i) { in_arc = arc_vector[i]; // l'erreur est probablement ici... - if (_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] - + if (_state[in_arc] * (getCostForArc(in_arc) + _pi[_source[in_arc]] - _pi[_target[in_arc]]) >= 0) continue; findJoinNode(); bool change = findLeavingArc(); @@ -1436,27 +1561,6 @@ namespace lemon { retVal = MAX_ITER_REACHED; break; } -#if DEBUG_LVL>0 - if(iter_number>MAX_DEBUG_ITER) - break; - if(iter_number%1000==0||iter_number%1000==1){ - double curCost=totalCost(); - double sumFlow=0; - double a; - a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? 
fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]); - a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]); - for (int64_t i=0; i<_flow.size(); i++) { - sumFlow+=_state[i]*_flow[i]; - } - std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; - std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; - std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; - std::cout << _cost[in_arc] << "\n"; - std::cout << _pi[_source[in_arc]] << "\n"; - std::cout << _pi[_target[in_arc]] << "\n"; - std::cout << a << "\n"; - } -#endif findJoinNode(); bool change = findLeavingArc(); @@ -1466,45 +1570,9 @@ namespace lemon { updateTreeStructure(); updatePotential(); } -#if DEBUG_LVL>0 - else{ - std::cout << "No change\n"; - } -#endif -#if DEBUG_LVL>1 - std::cout << "Arc in = (" << _source[in_arc] << ", " << _target[in_arc] << ")\n"; -#endif } - -#if DEBUG_LVL>0 - double curCost=totalCost(); - double sumFlow=0; - double a; - a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]); - a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]); - for (int64_t i=0; i<_flow.size(); i++) { - sumFlow+=_state[i]*_flow[i]; - } - - std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; - - std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; - std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; - -#endif - -#if DEBUG_LVL>1 - sumFlow=0; - for (int i=0; i<_flow.size(); i++) { - sumFlow+=_state[i]*_flow[i]; - if (_state[i]==STATE_TREE) { - std::cout << "Non zero value at (" << _node_num+1-_source[i] << ", " << _node_num+1-_target[i] << ")\n"; - } - } - std::cout << "Sum of the flow " << sumFlow << "\n"<< niter <<" iterations, current cost=" << totalCost() << "\n"; -#endif // Check feasibility if( retVal == OPTIMAL){ for (ArcsType e = _search_arc_num; e != _all_arc_num; ++e) { diff --git a/ot/regpath.py b/ot/regpath.py index e64ca7c77..aedc35b88 100644 --- a/ot/regpath.py +++ b/ot/regpath.py @@ -486,6 +486,9 @@ def complement_schur(M_current, b, d, id_pop): else: X = M_current.dot(b) s = d - b.T.dot(X) + # Ensure s is a scalar (extract from array if needed) + if np.ndim(s) > 0: + s = s.item() M = np.zeros((n, n)) M[:-1, :-1] = M_current + X.dot(X.T) / s X_ravel = X.ravel() diff --git a/ot/solvers.py b/ot/solvers.py index 68b389d63..c85f2dd76 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -8,7 +8,7 @@ # License: MIT License from .utils import OTResult, dist -from .lp import emd2, wasserstein_1d +from .lp import emd2, emd2_lazy, wasserstein_1d from .backend import get_backend from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced from .bregman import ( @@ -1749,6 +1749,36 @@ def solve_sample( return res + elif ( + lazy + and method is None + and (reg is None or reg == 0) + and unbalanced is None + and X_a is not None + and X_b is not None + ): + # Use lazy EMD solver with coordinates (no regularization, balanced) + 
value_linear, log = emd2_lazy( + a, + b, + X_a, + X_b, + metric=metric, + numItermax=max_iter if max_iter is not None else 100000, + log=True, + return_matrix=True, + ) + + res = OTResult( + potentials=(log["u"], log["v"]), + value=value_linear, + value_linear=value_linear, + plan=log["G"], + status=log["warning"] if log["warning"] is not None else "Converged", + ) + + return res + else: # Detect backend nx = get_backend(X_a, X_b, a, b) diff --git a/ot/utils.py b/ot/utils.py index cc9de4f02..64bf1ace9 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -429,7 +429,9 @@ def dist( else: if isinstance(metric, str) and metric.endswith("minkowski"): return cdist(x1, x2, metric=metric, p=p, w=w) - if w is not None: + # Only pass w parameter for metrics that support it + # According to SciPy docs, only 'minkowski' and 'wminkowski' support w + if w is not None and metric in ["minkowski", "wminkowski"]: return cdist(x1, x2, metric=metric, w=w) return cdist(x1, x2, metric=metric) diff --git a/test/test_solvers.py b/test/test_solvers.py index 040b38dc6..548ce4a05 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -53,8 +53,8 @@ "method": "nystroem", "metric": "euclidean", }, # fail nystroem on metric not euclidean - {"lazy": True}, # fail lazy for non regularized - {"lazy": True, "unbalanced": 1}, # fail lazy for non regularized unbalanced + # Note: {"lazy": True} now works - lazy EMD solver implemented + {"lazy": True, "unbalanced": 1}, # fail lazy for unbalanced (not supported) { "lazy": True, "reg": 1, @@ -601,6 +601,79 @@ def test_solve_sample_lazy(nx): np.testing.assert_allclose(sol0.plan, sol.lazy_plan[:], rtol=1e-5, atol=1e-5) +@pytest.mark.parametrize("metric", ["sqeuclidean", "euclidean", "cityblock"]) +def test_solve_sample_lazy_emd(nx, metric): + # test lazy EMD solver (no regularization, computes distances on-the-fly) + n_s = 20 + n_t = 25 + d = 2 + rng = np.random.RandomState(42) + + X_s = rng.rand(n_s, d) + X_t = rng.rand(n_t, d) + a = ot.utils.unif(n_s) + b = ot.utils.unif(n_t) + + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + + # Standard solver: pre-compute distance matrix + M = ot.dist(X_sb, X_tb, metric=metric) + sol_standard = ot.solve(M, ab, bb) + + # Lazy solver: compute distances on-the-fly + sol_lazy = ot.solve_sample(X_sb, X_tb, ab, bb, lazy=True, metric=metric) + + # Check that results match + np.testing.assert_allclose( + nx.to_numpy(sol_standard.value), + nx.to_numpy(sol_lazy.value), + rtol=1e-10, + atol=1e-10, + err_msg=f"Lazy EMD cost mismatch for metric {metric}", + ) + + np.testing.assert_allclose( + nx.to_numpy(sol_standard.plan), + nx.to_numpy(sol_lazy.plan), + rtol=1e-10, + atol=1e-10, + err_msg=f"Lazy EMD plan mismatch for metric {metric}", + ) + + +def test_solve_sample_lazy_emd_large(nx): + # Test larger problem to verify memory savings benefit + n_large = 100 + d = 2 + rng = np.random.RandomState(42) + + X_s_large = rng.rand(n_large, d) + X_t_large = rng.rand(n_large, d) + a_large = ot.utils.unif(n_large) + b_large = ot.utils.unif(n_large) + + X_sb_large, X_tb_large, ab_large, bb_large = nx.from_numpy( + X_s_large, X_t_large, a_large, b_large + ) + + # Standard solver + M_large = ot.dist(X_sb_large, X_tb_large, metric="sqeuclidean") + sol_standard_large = ot.solve(M_large, ab_large, bb_large) + + # Lazy solver (avoids storing 100x100 cost matrix) + sol_lazy_large = ot.solve_sample( + X_sb_large, X_tb_large, ab_large, bb_large, lazy=True, metric="sqeuclidean" + ) + + np.testing.assert_allclose( + nx.to_numpy(sol_standard_large.value), + 
nx.to_numpy(sol_lazy_large.value), + rtol=1e-9, + atol=1e-9, + err_msg="Lazy EMD cost mismatch for large problem", + ) + + @pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") @pytest.mark.skipif(not geomloss, reason="pytorch not installed") @pytest.skip_backend("tf") diff --git a/test/test_utils.py b/test/test_utils.py index 0b2769109..8c5e65b93 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -19,7 +19,7 @@ "correlation", ] -lst_all_metrics = lst_metrics + [ +lst_all_metrics_candidates = lst_metrics + [ "braycurtis", "canberra", "chebyshev", @@ -34,6 +34,18 @@ "yule", ] +# Filter to only include metrics available in current SciPy version +# (some metrics like sokalmichener were removed in newer SciPy versions) +lst_all_metrics = [] +for metric in lst_all_metrics_candidates: + try: + scipy.spatial.distance.cdist( + np.array([[0, 0]]), np.array([[1, 1]]), metric=metric + ) + lst_all_metrics.append(metric) + except ValueError: + pass + def get_LazyTensor(nx): n1 = 100 @@ -240,7 +252,18 @@ def test_dist(): "seuclidean", ] # do not support weights depending on scipy's version + # Filter out metrics not available in current scipy version + from scipy.spatial.distance import cdist + + available_metrics_w = [] for metric in metrics_w: + try: + cdist(x[:2], x[:2], metric=metric) + available_metrics_w.append(metric) + except ValueError: + pass + + for metric in available_metrics_w: print(metric) ot.dist(x, x, metric=metric, p=3, w=rng.random((2,))) ot.dist(