diff --git a/RELEASES.md b/RELEASES.md index cbe6c84fe..123b11f8d 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -2,15 +2,18 @@ ## 0.9.7.dev0 -This new release adds support for sparse cost matrices in the exact EMD solver. Users can now pass sparse cost matrices (e.g., k-NN graphs, sparse graphs) and receive sparse transport plans, significantly reducing memory footprint for large-scale problems. The implementation is backend-agnostic, automatically handling scipy.sparse for NumPy and torch.sparse for PyTorch, and preserves full gradient computation capabilities for automatic differentiation in PyTorch. This enables efficient solving of OT problems on graphs with millions of nodes where only a sparse subset of edges have finite costs. +This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation. #### New features +- Add lazy EMD solver with on-the-fly distance computation from coordinates (PR #788) - Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782) -- Geomloss function now handles both scalar and slice indices for i and j. Using backend agnostic reshaping. Allows to do plan[i,:] and plan[:,j] (PR #785) +- Geomloss function now handles both scalar and slice indices for i and j (PR #785) - Add support for sparse cost matrices in EMD solver (PR #778, Issue #397) #### Closed issues -- Fix O(n³) performance bottleneck in sparse bipartite graph arc iteration (PR #785) +- Fix NumPy 2.x compatibility in Brenier potential bounds (PR #788) +- Fix MSVC Windows build by removing __restrict__ keyword (PR #788) +- Fix O(n³) performance bottleneck in sparse bipartite graph arc iteration (PR #785) - Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770) - Add test for build from source (PR #772, Issue #764) - Fix device for batch Ot solver in `ot.batch` (PR #784, Issue #783) diff --git a/ot/__init__.py b/ot/__init__.py index 8a389bb98..26f428aa1 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -41,6 +41,7 @@ from .lp import ( emd, emd2, + emd2_lazy, emd_1d, emd2_1d, wasserstein_1d, @@ -82,6 +83,7 @@ __all__ = [ "emd", "emd2", + "emd2_lazy", "emd_1d", "sinkhorn", "sinkhorn2", diff --git a/ot/gromov/_estimators.py b/ot/gromov/_estimators.py index 14871bfe3..18359afe5 100644 --- a/ot/gromov/_estimators.py +++ b/ot/gromov/_estimators.py @@ -122,8 +122,8 @@ def GW_distance_estimation( for i in range(nb_samples_p): if nx.issparse(T): - T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,)) - T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,)) + T_indexi = nx.reshape(nx.todense(T[[index_i[i]], :]), (-1,)) + T_indexj = nx.reshape(nx.todense(T[[index_j[i]], :]), (-1,)) else: T_indexi = T[index_i[i], :] T_indexj = T[index_j[i], :] @@ -243,16 +243,18 @@ def pointwise_gromov_wasserstein( index = np.zeros(2, dtype=int) # Initialize with default marginal - index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p)) - index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q)) + index[0] = int(generator.choice(len_p, size=1, p=nx.to_numpy(p)).item()) + index[1] = int(generator.choice(len_q, size=1, p=nx.to_numpy(q)).item()) T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)) best_gw_dist_estimated = np.inf for cpt in range(max_iter): - index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p)) - T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,)) - index[1] = generator.choice( - len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0)) + index[0] = int(generator.choice(len_p, size=1, p=nx.to_numpy(p)).item()) + T_index0 = nx.reshape(nx.todense(T[[index[0]], :]), (-1,)) + index[1] = int( + generator.choice( + len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0)) + ).item() ) if alpha == 1: @@ -404,10 +406,15 @@ def sampled_gromov_wasserstein( ) Lik = 0 for i, index0_i in enumerate(index0): + T_row = ( + nx.reshape(nx.todense(T[[index0_i], :]), (-1,)) + if nx.issparse(T) + else T[index0_i, :] + ) index1 = generator.choice( len_q, size=nb_samples_grad_q, - p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])), + p=nx.to_numpy(T_row / nx.sum(T_row)), replace=False, ) # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly. diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index e3564a2d2..6f408ffeb 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -51,5 +51,21 @@ int EMD_wrap_sparse( uint64_t maxIter // Maximum iterations for solver ); +int EMD_wrap_lazy( + int n1, // Number of source points + int n2, // Number of target points + double *X, // Source weights (n1) + double *Y, // Target weights (n2) + double *coords_a, // Source coordinates (n1 x dim) + double *coords_b, // Target coordinates (n2 x dim) + int dim, // Dimension of coordinates + int metric, // Distance metric: 0=sqeuclidean, 1=euclidean, 2=cityblock + double *G, // Output: transport plan (n1 x n2) + double *alpha, // Output: dual variables for sources (n1) + double *beta, // Output: dual variables for targets (n2) + double *cost, // Output: total transportation cost + uint64_t maxIter // Maximum iterations for solver +); + #endif diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index bd3672535..6aa27897a 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -370,4 +370,108 @@ int EMD_wrap_sparse( } } return ret; -} \ No newline at end of file +} + +int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, + int dim, int metric, double *G, double *alpha, double *beta, + double *cost, uint64_t maxIter) { + using namespace lemon; + typedef FullBipartiteDigraph Digraph; + DIGRAPH_TYPEDEFS(Digraph); + + // Filter source nodes with non-zero weights + std::vector idx_a; + std::vector weights_a_filtered; + std::vector coords_a_filtered; + + // Reserve space to avoid reallocations + idx_a.reserve(n1); + weights_a_filtered.reserve(n1); + coords_a_filtered.reserve(n1 * dim); + + for (int i = 0; i < n1; i++) { + if (X[i] > 0) { + idx_a.push_back(i); + weights_a_filtered.push_back(X[i]); + for (int d = 0; d < dim; d++) { + coords_a_filtered.push_back(coords_a[i * dim + d]); + } + } + } + int n = idx_a.size(); + + // Filter target nodes with non-zero weights + std::vector idx_b; + std::vector weights_b_filtered; + std::vector coords_b_filtered; + + // Reserve space to avoid reallocations + idx_b.reserve(n2); + weights_b_filtered.reserve(n2); + coords_b_filtered.reserve(n2 * dim); + + for (int j = 0; j < n2; j++) { + if (Y[j] > 0) { + idx_b.push_back(j); + weights_b_filtered.push_back(-Y[j]); // Demand is negative supply + for (int d = 0; d < dim; d++) { + coords_b_filtered.push_back(coords_b[j * dim + d]); + } + } + } + int m = idx_b.size(); + + if (n == 0 || m == 0) { + *cost = 0.0; + return 0; + } + + // Create full bipartite graph + Digraph di(n, m); + + NetworkSimplexSimple net( + di, true, (int)(n + m), (uint64_t)(n) * (uint64_t)(m), maxIter + ); + + // Set supplies + net.supplyMap(&weights_a_filtered[0], n, &weights_b_filtered[0], m); + + // Enable lazy cost computation - costs will be computed on-the-fly + net.setLazyCost(&coords_a_filtered[0], &coords_b_filtered[0], dim, metric, n, m); + + // Run solver + int ret = net.run(); + + if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED) { + *cost = 0; + + // Initialize output arrays + for (int i = 0; i < n1 * n2; i++) G[i] = 0.0; + for (int i = 0; i < n1; i++) alpha[i] = 0.0; + for (int i = 0; i < n2; i++) beta[i] = 0.0; + + // Extract solution + Arc a; + di.first(a); + for (; a != INVALID; di.next(a)) { + int i = di.source(a); + int j = di.target(a) - n; + + int orig_i = idx_a[i]; + int orig_j = idx_b[j]; + + double flow = net.flow(a); + G[orig_i * n2 + orig_j] = flow; + + alpha[orig_i] = -net.potential(i); + beta[orig_j] = net.potential(j + n); + + if (flow > 0) { + double c = net.computeLazyCost(i, j); + *cost += flow * c; + } + } + } + + return ret; +} diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index f8924a322..8e88d63c8 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -9,7 +9,7 @@ # License: MIT License from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize -from ._network_simplex import emd, emd2 +from ._network_simplex import emd, emd2, emd2_lazy from ._barycenter_solvers import ( barycenter, free_support_barycenter, @@ -35,6 +35,7 @@ __all__ = [ "emd", "emd2", + "emd2_lazy", "barycenter", "free_support_barycenter", "cvx", diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py index d4dfa1ec3..ec06298bc 100644 --- a/ot/lp/_network_simplex.py +++ b/ot/lp/_network_simplex.py @@ -13,7 +13,7 @@ from ..utils import list_to_array, check_number_threads from ..backend import get_backend -from .emd_wrap import emd_c, emd_c_sparse, check_result +from .emd_wrap import emd_c, emd_c_sparse, emd_c_lazy, check_result def center_ot_dual(alpha0, beta0, a=None, b=None): @@ -320,20 +320,20 @@ def emd( if edge_costs.dtype != np.float64: edge_costs = edge_costs.astype(np.float64) - if len(a) != 0: + if a is not None and len(a) != 0: type_as = a - elif len(b) != 0: + elif b is not None and len(b) != 0: type_as = b else: - type_as = a + type_as = a if a is not None else b # Set n1, n2 if not already set (dense case) if n1 is None: n1, n2 = M.shape - if len(a) == 0: + if a is None or len(a) == 0: a = nx.ones((n1,), type_as=type_as) / n1 - if len(b) == 0: + if b is None or len(b) == 0: b = nx.ones((n2,), type_as=type_as) / n2 if is_sparse: @@ -471,7 +471,7 @@ def emd2( .. note:: This function will cast the computed transport plan and transportation loss to the data type of the provided input with the - following priority: :math:`\mathbf{a}`, then :math:`\mathbf{b}`, + following priority : :math:`\mathbf{a}`, then :math:`\mathbf{b}`, then :math:`\mathbf{M}` if marginals are not provided. Casting to an integer tensor might result in a loss of precision. If this behaviour is unwanted, please make sure to provide a @@ -591,21 +591,21 @@ def emd2( if edge_costs.dtype != np.float64: edge_costs = edge_costs.astype(np.float64) - if len(a) != 0: + if a is not None and len(a) != 0: type_as = a - elif len(b) != 0: + elif b is not None and len(b) != 0: type_as = b else: - type_as = a + type_as = a if a is not None else b # Set n1, n2 if not already set (dense case) if n1 is None: n1, n2 = M.shape # if empty array given then use uniform distributions - if len(a) == 0: + if a is None or len(a) == 0: a = nx.ones((n1,), type_as=type_as) / n1 - if len(b) == 0: + if b is None or len(b) == 0: b = nx.ones((n2,), type_as=type_as) / n2 a0, b0 = a, b @@ -775,3 +775,180 @@ def f(b): res = list(map(f, [b[:, i].copy() for i in range(nb)])) return res + + +def emd2_lazy( + X_a, + X_b, + a=None, + b=None, + metric="sqeuclidean", + numItermax=100000, + log=False, + return_matrix=True, + center_dual=True, + check_marginals=True, +): + r"""Solves the Earth Movers distance problem with lazy cost computation and returns the loss + + .. math:: + \min_\gamma \quad \langle \gamma, \mathbf{M}(\mathbf{X}_a, \mathbf{X}_b) \rangle_F + + s.t. \ \gamma \mathbf{1} = \mathbf{a} + + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 + + where : + + - :math:`\mathbf{M}(\mathbf{X}_a, \mathbf{X}_b)` is computed on-the-fly from coordinates + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + + .. note:: This function computes distances on-the-fly during the network simplex algorithm, + avoiding the O(ns*nt) memory cost of pre-computing the full cost matrix. Memory usage + is O(ns+nt) instead. + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + Parameters + ---------- + X_a : (ns, dim) array-like, float64 + Source sample coordinates + X_b : (nt, dim) array-like, float64 + Target sample coordinates + a : (ns,) array-like, float64, optional + Source histogram (uniform weight if None) + b : (nt,) array-like, float64, optional + Target histogram (uniform weight if None) + metric : str, optional (default='sqeuclidean') + Distance metric for cost computation. Options: + + - 'sqeuclidean': Squared Euclidean distance + - 'euclidean': Euclidean distance + - 'cityblock': Manhattan/L1 distance + + numItermax : int, optional (default=100000) + Maximum number of iterations before stopping if not converged + log: boolean, optional (default=False) + If True, returns a dictionary containing the cost, dual variables, + and optionally the transport plan (sparse format) + return_matrix: boolean, optional (default=False) + If True, returns the optimal transportation matrix in the log (sparse format) + center_dual: boolean, optional (default=True) + If True, centers the dual potential using :py:func:`ot.lp.center_ot_dual` + check_marginals: bool, optional (default=True) + If True, checks that the marginals mass are equal + + Returns + ------- + W: float + Optimal transportation loss + log: dict + If input log is True, a dictionary containing: + + - cost: the optimal transportation cost + - u, v: dual variables + - warning: solver status message + - result_code: solver return code + - G: (optional) sparse transport plan if return_matrix=True + + See Also + -------- + ot.emd2 : EMD with pre-computed cost matrix + ot.lp.emd_c_lazy : Low-level C++ lazy solver + """ + + a, b, X_a, X_b = list_to_array(a, b, X_a, X_b) + nx = get_backend(a, b, X_a, X_b) + + n1, n2 = X_a.shape[0], X_b.shape[0] + + if X_a.shape[1] != X_b.shape[1]: + raise ValueError( + f"X_a and X_b must have the same number of dimensions, " + f"got {X_a.shape[1]} and {X_b.shape[1]}" + ) + + if a is not None and len(a) != 0: + type_as = a + elif b is not None and len(b) != 0: + type_as = b + else: + type_as = X_a + + if a is None or len(a) == 0: + a = nx.ones((n1,), type_as=type_as) / n1 + if b is None or len(b) == 0: + b = nx.ones((n2,), type_as=type_as) / n2 + + a0, b0 = a, b + + # Convert to numpy for C++ backend + X_a_np = nx.to_numpy(X_a) + X_b_np = nx.to_numpy(X_b) + a_np = nx.to_numpy(a) + b_np = nx.to_numpy(b) + + X_a_np = np.asarray(X_a_np, dtype=np.float64, order="C") + X_b_np = np.asarray(X_b_np, dtype=np.float64, order="C") + a_np = np.asarray(a_np, dtype=np.float64) + b_np = np.asarray(b_np, dtype=np.float64) + + assert ( + a_np.shape[0] == n1 and b_np.shape[0] == n2 + ), "Dimension mismatch, check dimensions of X_a/X_b with a and b" + + if check_marginals: + np.testing.assert_almost_equal( + a_np.sum(), + b_np.sum(), + err_msg="a and b vector must have the same sum", + decimal=6, + ) + b_np = b_np * a_np.sum() / b_np.sum() + + G, cost, u, v, result_code = emd_c_lazy( + a_np, b_np, X_a_np, X_b_np, metric, numItermax + ) + + if center_dual: + u, v = center_ot_dual(u, v, a_np, b_np) + + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2, + ) + + G_backend = nx.from_numpy(G, type_as=type_as) + + cost_backend = nx.set_gradients( + nx.from_numpy(cost, type_as=type_as), + (a0, b0), + ( + nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), + ), + ) + + check_result(result_code) + + if log or return_matrix: + log_dict = { + "cost": cost_backend, + "u": nx.from_numpy(u, type_as=type_as), + "v": nx.from_numpy(v, type_as=type_as), + "warning": check_result(result_code), + "result_code": result_code, + } + if return_matrix: + log_dict["G"] = G_backend + return cost_backend, log_dict + else: + return cost_backend diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 4ce315f5f..4f483dfe9 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -23,6 +23,7 @@ cdef extern from "EMD.h": int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil + int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, int dim, int metric, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -285,4 +286,42 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a, flow_targets = flow_targets[:n_flows_out] flow_values = flow_values[:n_flows_out] - return flow_sources, flow_targets, flow_values, cost, alpha, beta, result_code \ No newline at end of file + return flow_sources, flow_targets, flow_values, cost, alpha, beta, result_code + + +@cython.boundscheck(False) +@cython.wraparound(False) +def emd_c_lazy(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] coords_a, np.ndarray[double, ndim=2, mode="c"] coords_b, str metric='sqeuclidean', uint64_t max_iter=100000): + """Solves the Earth Movers distance problem with lazy cost computation from coordinates.""" + cdef int n1 = coords_a.shape[0] + cdef int n2 = coords_b.shape[0] + cdef int dim = coords_a.shape[1] + cdef int result_code = 0 + cdef double cost = 0 + cdef int metric_code + + # Validate dimension consistency + if coords_b.shape[1] != dim: + raise ValueError(f"Coordinate dimension mismatch: coords_a has {dim} dimensions but coords_b has {coords_b.shape[1]}") + + metric_map = { + 'sqeuclidean': 0, + 'euclidean': 1, + 'cityblock': 2 + } + + try: + metric_code = metric_map[metric] + except KeyError: + raise ValueError(f"Unknown metric: '{metric}'. Supported metrics are: {list(metric_map.keys())}") + + cdef np.ndarray[double, ndim=1, mode="c"] alpha = np.zeros(n1) + cdef np.ndarray[double, ndim=1, mode="c"] beta = np.zeros(n2) + cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros([n1, n2]) + if not len(a): + a = np.ones((n1,)) / n1 + if not len(b): + b = np.ones((n2,)) / n2 + with nogil: + result_code = EMD_wrap_lazy(n1, n2, a.data, b.data, coords_a.data, coords_b.data, dim, metric_code, G.data, alpha.data, beta.data, &cost, max_iter) + return G, cost, alpha, beta, result_code diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 9612a8a24..c9fef277e 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -27,18 +27,10 @@ #pragma once #undef DEBUG_LVL -#define DEBUG_LVL 0 - -#if DEBUG_LVL>0 -#include -#endif - #undef EPSILON #undef _EPSILON -#undef MAX_DEBUG_ITER #define EPSILON 2.2204460492503131e-15 #define _EPSILON 1e-8 -#define MAX_DEBUG_ITER 100000 /// \ingroup min_cost_flow_algs @@ -238,7 +230,8 @@ namespace lemon { _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs), MAX(std::numeric_limits::max()), INF(std::numeric_limits::has_infinity ? - std::numeric_limits::infinity() : MAX) + std::numeric_limits::infinity() : MAX), + _lazy_cost(false), _coords_a(nullptr), _coords_b(nullptr), _dim(0), _metric(0), _n1(0), _n2(0) { // Reset data structures reset(); @@ -320,6 +313,8 @@ namespace lemon { // Data related to the underlying digraph const GR &_graph; int _node_num; + int _n1; // Number of source nodes (for lazy cost computation) + int _n2; // Number of target nodes (for lazy cost computation) ArcsType _arc_num; ArcsType _all_arc_num; ArcsType _search_arc_num; @@ -342,6 +337,12 @@ namespace lemon { //SparseValueVector _flow; CostVector _pi; + // Lazy cost computation support + bool _lazy_cost; + const double* _coords_a; + const double* _coords_b; + int _dim; + int _metric; // 0: sqeuclidean, 1: euclidean, 2: cityblock private: // Data for storing the spanning tree structure @@ -470,6 +471,41 @@ namespace lemon { _block_size = std::max(ArcsType(BLOCK_SIZE_FACTOR * std::sqrt(double(_search_arc_num))), MIN_BLOCK_SIZE); } + // Get cost for an arc (either from pre-computed array or compute lazily) + inline Cost getCost(ArcsType e) const { + if (!_ns._lazy_cost) { + return _cost[e]; + } else { + // For lazy mode, compute cost from coordinates inline + // _source and _target use reversed node numbering + int i = _ns._node_num - _source[e] - 1; + int j = _ns._n2 - _target[e] - 1; + + const double* xa = _ns._coords_a + i * _ns._dim; + const double* xb = _ns._coords_b + j * _ns._dim; + Cost cost = 0; + + if (_ns._metric == 0) { // sqeuclidean + for (int d = 0; d < _ns._dim; ++d) { + Cost diff = xa[d] - xb[d]; + cost += diff * diff; + } + return cost; + } else if (_ns._metric == 1) { // euclidean + for (int d = 0; d < _ns._dim; ++d) { + Cost diff = xa[d] - xb[d]; + cost += diff * diff; + } + return std::sqrt(cost); + } else { // cityblock + for (int d = 0; d < _ns._dim; ++d) { + cost += std::abs(xa[d] - xb[d]); + } + return cost; + } + } + } + // Find next entering arc bool findEnteringArc() { Cost c, min = 0; @@ -477,33 +513,33 @@ namespace lemon { ArcsType cnt = _block_size; double a; for (e = _next_arc; e != _search_arc_num; ++e) { - c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); + c = _state[e] * (getCost(e) + _pi[_source[e]] - _pi[_target[e]]); if (c < min) { min = c; _in_arc = e; } if (--cnt == 0) { a=fabs(_pi[_source[_in_arc]])>fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]):fabs(_pi[_target[_in_arc]]); - a=a>fabs(_cost[_in_arc])?a:fabs(_cost[_in_arc]); + a=a>fabs(getCost(_in_arc))?a:fabs(getCost(_in_arc)); if (min < -EPSILON*a) goto search_end; cnt = _block_size; } } for (e = 0; e != _next_arc; ++e) { - c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); + c = _state[e] * (getCost(e) + _pi[_source[e]] - _pi[_target[e]]); if (c < min) { min = c; _in_arc = e; } if (--cnt == 0) { a=fabs(_pi[_source[_in_arc]])>fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]):fabs(_pi[_target[_in_arc]]); - a=a>fabs(_cost[_in_arc])?a:fabs(_cost[_in_arc]); + a=a>fabs(getCost(_in_arc))?a:fabs(getCost(_in_arc)); if (min < -EPSILON*a) goto search_end; cnt = _block_size; } } a=fabs(_pi[_source[_in_arc]])>fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]):fabs(_pi[_target[_in_arc]]); - a=a>fabs(_cost[_in_arc])?a:fabs(_cost[_in_arc]); + a=a>fabs(getCost(_in_arc))?a:fabs(getCost(_in_arc)); if (min >= -EPSILON*a) return false; search_end: @@ -565,6 +601,90 @@ namespace lemon { return *this; } + /// \brief Enable lazy cost computation from coordinates. + /// + /// This function enables lazy cost computation where distances are + /// computed on-the-fly from point coordinates instead of using a + /// pre-computed cost matrix. + /// + /// \param coords_a Pointer to source coordinates (n1 x dim array) + /// \param coords_b Pointer to target coordinates (n2 x dim array) + /// \param dim Dimension of the coordinates + /// \param metric Distance metric: 0=sqeuclidean, 1=euclidean, 2=cityblock + /// + /// \return (*this) + NetworkSimplexSimple& setLazyCost(const double* coords_a, const double* coords_b, + int dim, int metric, int n1, int n2) { + _lazy_cost = true; + _coords_a = coords_a; + _coords_b = coords_b; + _dim = dim; + _metric = metric; + _n1 = n1; + _n2 = n2; + return *this; + } + + /// \brief Compute cost lazily from coordinates. + /// + /// Computes the distance between source node i and target node j + /// based on the specified metric. + /// + /// \param i Source node index + /// \param j Target node index (adjusted by subtracting n1) + /// + /// \return Cost (distance) between the two points + inline Cost computeLazyCost(int i, int j_adjusted) const { + const double* xa = _coords_a + i * _dim; + const double* xb = _coords_b + j_adjusted * _dim; + Cost cost = 0; + + if (_metric == 0) { // sqeuclidean + for (int d = 0; d < _dim; ++d) { + Cost diff = xa[d] - xb[d]; + cost += diff * diff; + } + return cost; + } else if (_metric == 1) { // euclidean + for (int d = 0; d < _dim; ++d) { + Cost diff = xa[d] - xb[d]; + cost += diff * diff; + } + return std::sqrt(cost); + } else { // cityblock (L1) + for (int d = 0; d < _dim; ++d) { + cost += std::abs(xa[d] - xb[d]); + } + return cost; + } + } + + + /// \brief Get cost for an arc (either from array or compute lazily). + /// + /// This is the main cost accessor that works from anywhere in the class. + /// In lazy mode, computes cost on-the-fly from coordinates. + /// In normal mode, returns pre-computed cost from array. + /// + /// \param arc_id The arc ID + /// \return Cost of the arc + inline Cost getCostForArc(ArcsType arc_id) const { + if (!_lazy_cost) { + return _cost[arc_id]; + } else { + // For artificial arcs (>= _arc_num), return 0 + // These are not real transport arcs + if (arc_id >= _arc_num) { + return 0; + } + // Compute lazily from coordinates + // _source and _target use reversed node numbering: _node_id(n) = _node_num - n - 1 + // Recover original indices: i = _node_num - _source[arc_id] - 1, j = _n2 - _target[arc_id] - 1 + int i = _node_num - _source[arc_id] - 1; + int j = _n2 - _target[arc_id] - 1; + return computeLazyCost(i, j); + } + } /// \brief Set the supply values of the nodes. /// @@ -689,14 +809,7 @@ namespace lemon { /// \see ProblemType, PivotRule /// \see resetParams(), reset() ProblemType run() { -#if DEBUG_LVL>0 - std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED << "\n" ; -#endif - if (!init()) return INFEASIBLE; -#if DEBUG_LVL>0 - std::cout << "Init done, starting iterations\n"; -#endif return start(); } @@ -879,8 +992,19 @@ namespace lemon { c += Number(it->second) * Number(_cost[it->first]); return c;*/ - for (ArcsType i=0; i<_flow.size(); i++) - c += _flow[i] * Number(_cost[i]); + if (!_lazy_cost) { + for (ArcsType i=0; i<_flow.size(); i++) + c += _flow[i] * Number(_cost[i]); + } else { + // Compute costs lazily + for (ArcsType i=0; i<_flow.size(); i++) { + if (_flow[i] != 0) { + int src = _node_num - _source[i] - 1; + int tgt = _n2 - _target[i] - 1; + c += _flow[i] * Number(computeLazyCost(src, tgt)); + } + } + } return c; } @@ -965,7 +1089,8 @@ namespace lemon { } else { ART_COST = 0; for (ArcsType i = 0; i != _arc_num; ++i) { - if (_cost[i] > ART_COST) ART_COST = _cost[i]; + Cost c = getCostForArc(i); + if (c > ART_COST) ART_COST = c; } ART_COST = (ART_COST + 1) * _node_num; } @@ -1305,8 +1430,8 @@ namespace lemon { // Update potentials void updatePotential() { Cost sigma = _forward[u_in] ? - _pi[v_in] - _pi[u_in] - _cost[_pred[u_in]] : - _pi[v_in] - _pi[u_in] + _cost[_pred[u_in]]; + _pi[v_in] - _pi[u_in] - getCostForArc(_pred[u_in]) : + _pi[v_in] - _pi[u_in] + getCostForArc(_pred[u_in]); // Update potentials in the subtree, which has been moved int end = _thread[_last_succ[u_in]]; for (int u = u_in; u != end; u = _thread[u]) { @@ -1365,7 +1490,7 @@ namespace lemon { Arc min_arc = INVALID; Arc a; _graph.firstIn(a, v); for (; a != INVALID; _graph.nextIn(a)) { - c = _cost[getArcID(a)]; + c = getCostForArc(getArcID(a)); if (c < min_cost) { min_cost = c; min_arc = a; @@ -1384,7 +1509,7 @@ namespace lemon { Arc min_arc = INVALID; Arc a; _graph.firstOut(a, u); for (; a != INVALID; _graph.nextOut(a)) { - c = _cost[getArcID(a)]; + c = getCostForArc(getArcID(a)); if (c < min_cost) { min_cost = c; min_arc = a; @@ -1400,7 +1525,7 @@ namespace lemon { for (ArcsType i = 0; i != arc_vector.size(); ++i) { in_arc = arc_vector[i]; // l'erreur est probablement ici... - if (_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] - + if (_state[in_arc] * (getCostForArc(in_arc) + _pi[_source[in_arc]] - _pi[_target[in_arc]]) >= 0) continue; findJoinNode(); bool change = findLeavingArc(); @@ -1436,27 +1561,6 @@ namespace lemon { retVal = MAX_ITER_REACHED; break; } -#if DEBUG_LVL>0 - if(iter_number>MAX_DEBUG_ITER) - break; - if(iter_number%1000==0||iter_number%1000==1){ - double curCost=totalCost(); - double sumFlow=0; - double a; - a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]); - a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]); - for (int64_t i=0; i<_flow.size(); i++) { - sumFlow+=_state[i]*_flow[i]; - } - std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; - std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; - std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; - std::cout << _cost[in_arc] << "\n"; - std::cout << _pi[_source[in_arc]] << "\n"; - std::cout << _pi[_target[in_arc]] << "\n"; - std::cout << a << "\n"; - } -#endif findJoinNode(); bool change = findLeavingArc(); @@ -1466,45 +1570,9 @@ namespace lemon { updateTreeStructure(); updatePotential(); } -#if DEBUG_LVL>0 - else{ - std::cout << "No change\n"; - } -#endif -#if DEBUG_LVL>1 - std::cout << "Arc in = (" << _source[in_arc] << ", " << _target[in_arc] << ")\n"; -#endif } - -#if DEBUG_LVL>0 - double curCost=totalCost(); - double sumFlow=0; - double a; - a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]); - a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]); - for (int64_t i=0; i<_flow.size(); i++) { - sumFlow+=_state[i]*_flow[i]; - } - - std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; - - std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; - std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; - -#endif - -#if DEBUG_LVL>1 - sumFlow=0; - for (int i=0; i<_flow.size(); i++) { - sumFlow+=_state[i]*_flow[i]; - if (_state[i]==STATE_TREE) { - std::cout << "Non zero value at (" << _node_num+1-_source[i] << ", " << _node_num+1-_target[i] << ")\n"; - } - } - std::cout << "Sum of the flow " << sumFlow << "\n"<< niter <<" iterations, current cost=" << totalCost() << "\n"; -#endif // Check feasibility if( retVal == OPTIMAL){ for (ArcsType e = _search_arc_num; e != _all_arc_num; ++e) { diff --git a/ot/lp/network_simplex_simple_omp.h b/ot/lp/network_simplex_simple_omp.h index 890b7ab03..d8ef672e3 100644 --- a/ot/lp/network_simplex_simple_omp.h +++ b/ot/lp/network_simplex_simple_omp.h @@ -703,7 +703,6 @@ namespace lemon_omp { return *this; } - /// \brief Set the supply values of the nodes. /// /// This function sets the supply values of the nodes. diff --git a/ot/mapping.py b/ot/mapping.py index cc3e6cd57..bb1a78e8a 100644 --- a/ot/mapping.py +++ b/ot/mapping.py @@ -376,7 +376,9 @@ def nearest_brenier_potential_predict_bounds( ] problem = cvx.Problem(objective, constraints) problem.solve(solver=solver) - phi_lu[0, y_idx] = phi_l_y.value + phi_lu[0, y_idx] = ( + phi_l_y.value if np.isscalar(phi_l_y.value) else phi_l_y.value.item() + ) G_lu[0, y_idx] = G_l_y.value if log: log_item["l"] = { @@ -403,7 +405,9 @@ def nearest_brenier_potential_predict_bounds( ] problem = cvx.Problem(objective, constraints) problem.solve(solver=solver) - phi_lu[1, y_idx] = phi_u_y.value + phi_lu[1, y_idx] = ( + phi_u_y.value if np.isscalar(phi_u_y.value) else phi_u_y.value.item() + ) G_lu[1, y_idx] = G_u_y.value if log: log_item["u"] = { diff --git a/ot/regpath.py b/ot/regpath.py index e64ca7c77..aedc35b88 100644 --- a/ot/regpath.py +++ b/ot/regpath.py @@ -486,6 +486,9 @@ def complement_schur(M_current, b, d, id_pop): else: X = M_current.dot(b) s = d - b.T.dot(X) + # Ensure s is a scalar (extract from array if needed) + if np.ndim(s) > 0: + s = s.item() M = np.zeros((n, n)) M[:-1, :-1] = M_current + X.dot(X.T) / s X_ravel = X.ravel() diff --git a/ot/solvers.py b/ot/solvers.py index 68b389d63..99314da02 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -8,7 +8,7 @@ # License: MIT License from .utils import OTResult, dist -from .lp import emd2, wasserstein_1d +from .lp import emd2, emd2_lazy, wasserstein_1d from .backend import get_backend from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced from .bregman import ( @@ -1749,6 +1749,36 @@ def solve_sample( return res + elif ( + lazy + and method is None + and (reg is None or reg == 0) + and unbalanced is None + and X_a is not None + and X_b is not None + ): + # Use lazy EMD solver with coordinates (no regularization, balanced) + value_linear, log = emd2_lazy( + X_a, + X_b, + a, + b, + metric=metric, + numItermax=max_iter if max_iter is not None else 100000, + log=True, + return_matrix=True, + ) + + res = OTResult( + potentials=(log["u"], log["v"]), + value=value_linear, + value_linear=value_linear, + plan=log["G"], + status=log["warning"] if log["warning"] is not None else "Converged", + ) + + return res + else: # Detect backend nx = get_backend(X_a, X_b, a, b) diff --git a/ot/utils.py b/ot/utils.py index cc9de4f02..64bf1ace9 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -429,7 +429,9 @@ def dist( else: if isinstance(metric, str) and metric.endswith("minkowski"): return cdist(x1, x2, metric=metric, p=p, w=w) - if w is not None: + # Only pass w parameter for metrics that support it + # According to SciPy docs, only 'minkowski' and 'wminkowski' support w + if w is not None and metric in ["minkowski", "wminkowski"]: return cdist(x1, x2, metric=metric, w=w) return cdist(x1, x2, metric=metric) diff --git a/test/test_solvers.py b/test/test_solvers.py index 040b38dc6..9b0ababbf 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -53,8 +53,7 @@ "method": "nystroem", "metric": "euclidean", }, # fail nystroem on metric not euclidean - {"lazy": True}, # fail lazy for non regularized - {"lazy": True, "unbalanced": 1}, # fail lazy for non regularized unbalanced + {"lazy": True, "unbalanced": 1}, # fail lazy for unbalanced (not supported) { "lazy": True, "reg": 1, @@ -601,6 +600,79 @@ def test_solve_sample_lazy(nx): np.testing.assert_allclose(sol0.plan, sol.lazy_plan[:], rtol=1e-5, atol=1e-5) +@pytest.mark.parametrize("metric", ["sqeuclidean", "euclidean", "cityblock"]) +def test_solve_sample_lazy_emd(nx, metric): + # test lazy EMD solver (no regularization, computes distances on-the-fly) + n_s = 20 + n_t = 25 + d = 2 + rng = np.random.RandomState(42) + + X_s = rng.rand(n_s, d) + X_t = rng.rand(n_t, d) + a = ot.utils.unif(n_s) + b = ot.utils.unif(n_t) + + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + + # Standard solver: pre-compute distance matrix + M = ot.dist(X_sb, X_tb, metric=metric) + sol_standard = ot.solve(M, ab, bb) + + # Lazy solver: compute distances on-the-fly + sol_lazy = ot.solve_sample(X_sb, X_tb, ab, bb, lazy=True, metric=metric) + + # Check that results match + np.testing.assert_allclose( + nx.to_numpy(sol_standard.value), + nx.to_numpy(sol_lazy.value), + rtol=1e-10, + atol=1e-10, + err_msg=f"Lazy EMD cost mismatch for metric {metric}", + ) + + np.testing.assert_allclose( + nx.to_numpy(sol_standard.plan), + nx.to_numpy(sol_lazy.plan), + rtol=1e-10, + atol=1e-10, + err_msg=f"Lazy EMD plan mismatch for metric {metric}", + ) + + +def test_solve_sample_lazy_emd_large(nx): + # Test larger problem to verify memory savings benefit + n_large = 100 + d = 2 + rng = np.random.RandomState(42) + + X_s_large = rng.rand(n_large, d) + X_t_large = rng.rand(n_large, d) + a_large = ot.utils.unif(n_large) + b_large = ot.utils.unif(n_large) + + X_sb_large, X_tb_large, ab_large, bb_large = nx.from_numpy( + X_s_large, X_t_large, a_large, b_large + ) + + # Standard solver + M_large = ot.dist(X_sb_large, X_tb_large, metric="sqeuclidean") + sol_standard_large = ot.solve(M_large, ab_large, bb_large) + + # Lazy solver (avoids storing 100x100 cost matrix) + sol_lazy_large = ot.solve_sample( + X_sb_large, X_tb_large, ab_large, bb_large, lazy=True, metric="sqeuclidean" + ) + + np.testing.assert_allclose( + nx.to_numpy(sol_standard_large.value), + nx.to_numpy(sol_lazy_large.value), + rtol=1e-9, + atol=1e-9, + err_msg="Lazy EMD cost mismatch for large problem", + ) + + @pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") @pytest.mark.skipif(not geomloss, reason="pytorch not installed") @pytest.skip_backend("tf") diff --git a/test/test_utils.py b/test/test_utils.py index 0b2769109..8c5e65b93 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -19,7 +19,7 @@ "correlation", ] -lst_all_metrics = lst_metrics + [ +lst_all_metrics_candidates = lst_metrics + [ "braycurtis", "canberra", "chebyshev", @@ -34,6 +34,18 @@ "yule", ] +# Filter to only include metrics available in current SciPy version +# (some metrics like sokalmichener were removed in newer SciPy versions) +lst_all_metrics = [] +for metric in lst_all_metrics_candidates: + try: + scipy.spatial.distance.cdist( + np.array([[0, 0]]), np.array([[1, 1]]), metric=metric + ) + lst_all_metrics.append(metric) + except ValueError: + pass + def get_LazyTensor(nx): n1 = 100 @@ -240,7 +252,18 @@ def test_dist(): "seuclidean", ] # do not support weights depending on scipy's version + # Filter out metrics not available in current scipy version + from scipy.spatial.distance import cdist + + available_metrics_w = [] for metric in metrics_w: + try: + cdist(x[:2], x[:2], metric=metric) + available_metrics_w.append(metric) + except ValueError: + pass + + for metric in available_metrics_w: print(metric) ot.dist(x, x, metric=metric, p=3, w=rng.random((2,))) ot.dist(