From fb43b72876b0eb6d0860a8efeb0001e03a5765a8 Mon Sep 17 00:00:00 2001 From: nathanneike Date: Mon, 19 Jan 2026 17:43:45 +0100 Subject: [PATCH 1/8] Add lazy EMD solver with on-the-fly distance computation - Implement emd_c_lazy in C++ network simplex for memory-efficient OT - Add lazy mode to emd2() accepting coordinates (X_a, X_b) instead of cost matrix - Support sqeuclidean, euclidean, and cityblock metrics - Add __restrict__ for SIMD optimization - Remove debug output from network_simplex_simple.h - Add tests for lazy solver and metric variants --- ot/lp/EMD.h | 16 +++ ot/lp/EMD_wrapper.cpp | 106 ++++++++++++++- ot/lp/_network_simplex.py | 90 ++++++++++--- ot/lp/emd_wrap.pyx | 38 +++++- ot/lp/network_simplex_simple.h | 240 +++++++++++++++++++++------------ ot/solvers.py | 32 +++++ test/test_solvers.py | 73 +++++++++- 7 files changed, 487 insertions(+), 108 deletions(-) diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index e3564a2d2..6f408ffeb 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -51,5 +51,21 @@ int EMD_wrap_sparse( uint64_t maxIter // Maximum iterations for solver ); +int EMD_wrap_lazy( + int n1, // Number of source points + int n2, // Number of target points + double *X, // Source weights (n1) + double *Y, // Target weights (n2) + double *coords_a, // Source coordinates (n1 x dim) + double *coords_b, // Target coordinates (n2 x dim) + int dim, // Dimension of coordinates + int metric, // Distance metric: 0=sqeuclidean, 1=euclidean, 2=cityblock + double *G, // Output: transport plan (n1 x n2) + double *alpha, // Output: dual variables for sources (n1) + double *beta, // Output: dual variables for targets (n2) + double *cost, // Output: total transportation cost + uint64_t maxIter // Maximum iterations for solver +); + #endif diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index bd3672535..6aa27897a 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -370,4 +370,108 @@ int EMD_wrap_sparse( } } return ret; -} \ No newline at end of file +} + +int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, + int dim, int metric, double *G, double *alpha, double *beta, + double *cost, uint64_t maxIter) { + using namespace lemon; + typedef FullBipartiteDigraph Digraph; + DIGRAPH_TYPEDEFS(Digraph); + + // Filter source nodes with non-zero weights + std::vector idx_a; + std::vector weights_a_filtered; + std::vector coords_a_filtered; + + // Reserve space to avoid reallocations + idx_a.reserve(n1); + weights_a_filtered.reserve(n1); + coords_a_filtered.reserve(n1 * dim); + + for (int i = 0; i < n1; i++) { + if (X[i] > 0) { + idx_a.push_back(i); + weights_a_filtered.push_back(X[i]); + for (int d = 0; d < dim; d++) { + coords_a_filtered.push_back(coords_a[i * dim + d]); + } + } + } + int n = idx_a.size(); + + // Filter target nodes with non-zero weights + std::vector idx_b; + std::vector weights_b_filtered; + std::vector coords_b_filtered; + + // Reserve space to avoid reallocations + idx_b.reserve(n2); + weights_b_filtered.reserve(n2); + coords_b_filtered.reserve(n2 * dim); + + for (int j = 0; j < n2; j++) { + if (Y[j] > 0) { + idx_b.push_back(j); + weights_b_filtered.push_back(-Y[j]); // Demand is negative supply + for (int d = 0; d < dim; d++) { + coords_b_filtered.push_back(coords_b[j * dim + d]); + } + } + } + int m = idx_b.size(); + + if (n == 0 || m == 0) { + *cost = 0.0; + return 0; + } + + // Create full bipartite graph + Digraph di(n, m); + + NetworkSimplexSimple net( + di, true, (int)(n + m), (uint64_t)(n) * (uint64_t)(m), maxIter + ); + + // Set supplies + net.supplyMap(&weights_a_filtered[0], n, &weights_b_filtered[0], m); + + // Enable lazy cost computation - costs will be computed on-the-fly + net.setLazyCost(&coords_a_filtered[0], &coords_b_filtered[0], dim, metric, n, m); + + // Run solver + int ret = net.run(); + + if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED) { + *cost = 0; + + // Initialize output arrays + for (int i = 0; i < n1 * n2; i++) G[i] = 0.0; + for (int i = 0; i < n1; i++) alpha[i] = 0.0; + for (int i = 0; i < n2; i++) beta[i] = 0.0; + + // Extract solution + Arc a; + di.first(a); + for (; a != INVALID; di.next(a)) { + int i = di.source(a); + int j = di.target(a) - n; + + int orig_i = idx_a[i]; + int orig_j = idx_b[j]; + + double flow = net.flow(a); + G[orig_i * n2 + orig_j] = flow; + + alpha[orig_i] = -net.potential(i); + beta[orig_j] = net.potential(j + n); + + if (flow > 0) { + double c = net.computeLazyCost(i, j); + *cost += flow * c; + } + } + } + + return ret; +} diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py index d4dfa1ec3..9a4969a89 100644 --- a/ot/lp/_network_simplex.py +++ b/ot/lp/_network_simplex.py @@ -13,7 +13,7 @@ from ..utils import list_to_array, check_number_threads from ..backend import get_backend -from .emd_wrap import emd_c, emd_c_sparse, check_result +from .emd_wrap import emd_c, emd_c_sparse, emd_c_lazy, check_result def center_ot_dual(alpha0, beta0, a=None, b=None): @@ -320,20 +320,20 @@ def emd( if edge_costs.dtype != np.float64: edge_costs = edge_costs.astype(np.float64) - if len(a) != 0: + if a is not None and len(a) != 0: type_as = a - elif len(b) != 0: + elif b is not None and len(b) != 0: type_as = b else: - type_as = a + type_as = a if a is not None else b # Set n1, n2 if not already set (dense case) if n1 is None: n1, n2 = M.shape - if len(a) == 0: + if a is None or len(a) == 0: a = nx.ones((n1,), type_as=type_as) / n1 - if len(b) == 0: + if b is None or len(b) == 0: b = nx.ones((n2,), type_as=type_as) / n2 if is_sparse: @@ -440,7 +440,10 @@ def emd( def emd2( a, b, - M, + M=None, + X_a=None, + X_b=None, + metric="sqeuclidean", processes=1, numItermax=100000, log=False, @@ -487,13 +490,23 @@ def emd2( Source histogram (uniform weight if empty list) b : (nt,) array-like, float64 Target histogram (uniform weight if empty list) - M : (ns,nt) array-like or sparse matrix, float64 + M : (ns,nt) array-like or sparse matrix, float64, optional Loss matrix. Can be: - Dense array (c-order array in numpy with type float64) - Sparse matrix in backend's format (scipy.sparse.coo_matrix for NumPy backend, torch.sparse_coo_tensor for PyTorch backend, etc.) + Either M or (X_a, X_b) must be provided. + X_a : (ns, dim) array-like, float64, optional + Source coordinates for lazy cost computation. + If provided along with X_b, costs will be computed on-the-fly. + X_b : (nt, dim) array-like, float64, optional + Target coordinates for lazy cost computation. + If provided along with X_a, costs will be computed on-the-fly. + metric : str, optional (default='sqeuclidean') + Distance metric for lazy mode. Options: 'sqeuclidean', 'euclidean', 'cityblock'. + Only used when X_a and X_b are provided. processes : int, optional (default=1) Nb of processes used for multiple emd computation (deprecated) numItermax : int, optional (default=100000) @@ -564,11 +577,25 @@ def emd2( n1, n2 = None, None - a, b, M = list_to_array(a, b, M) - nx = get_backend(a, b, M) + # Check if we're using lazy mode with coordinates + use_lazy = X_a is not None and X_b is not None and M is None - # Check if M is sparse using backend's issparse method - is_sparse = nx.issparse(M) + if use_lazy: + # Lazy mode: coordinates provided + a, b, X_a, X_b = list_to_array(a, b, X_a, X_b) + nx = get_backend(a, b, X_a, X_b) + n1, n2 = X_a.shape[0], X_b.shape[0] + is_sparse = False + else: + # Standard mode: cost matrix provided + if M is None: + raise ValueError( + "Either M (cost matrix) or (X_a, X_b) coordinates must be provided" + ) + a, b, M = list_to_array(a, b, M) + nx = get_backend(a, b, M) + # Check if M is sparse using backend's issparse method + is_sparse = nx.issparse(M) # Save original sparse tensor for gradient tracking (before conversion to numpy) M_original_sparse = None @@ -591,21 +618,24 @@ def emd2( if edge_costs.dtype != np.float64: edge_costs = edge_costs.astype(np.float64) - if len(a) != 0: + if a is not None and len(a) != 0: type_as = a - elif len(b) != 0: + elif b is not None and len(b) != 0: type_as = b + elif use_lazy: + # In lazy mode, use coordinates for type inference + type_as = X_a else: - type_as = a + type_as = a if a is not None else b # Set n1, n2 if not already set (dense case) if n1 is None: n1, n2 = M.shape # if empty array given then use uniform distributions - if len(a) == 0: + if a is None or len(a) == 0: a = nx.ones((n1,), type_as=type_as) / n1 - if len(b) == 0: + if b is None or len(b) == 0: b = nx.ones((n2,), type_as=type_as) / n2 a0, b0 = a, b @@ -650,7 +680,10 @@ def emd2( def f(b): bsel = b != 0 - if is_sparse: + if use_lazy: + # Solve with lazy cost computation from coordinates + G, cost, u, v, result_code = emd_c_lazy(a, b, X_a, X_b, metric, numItermax) + elif is_sparse: # Solve sparse EMD flow_sources, flow_targets, flow_values, cost, u, v, result_code = ( emd_c_sparse(a, b, edge_sources, edge_targets, edge_costs, numItermax) @@ -724,6 +757,27 @@ def f(b): shape=(n1, n2), type_as=type_as, ) + elif use_lazy: + # Lazy case: warn about integer casting + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2, + ) + + G_backend = nx.from_numpy(G, type_as=type_as) + # For now, just set gradients wrt marginals + cost = nx.set_gradients( + nx.from_numpy(cost, type_as=type_as), + (a0, b0), + ( + nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), + ), + ) else: # Dense case: warn about integer casting if not nx.is_floating_point(type_as): diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 4ce315f5f..2b2a5a1ef 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -23,6 +23,7 @@ cdef extern from "EMD.h": int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil + int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, int dim, int metric, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -285,4 +286,39 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a, flow_targets = flow_targets[:n_flows_out] flow_values = flow_values[:n_flows_out] - return flow_sources, flow_targets, flow_values, cost, alpha, beta, result_code \ No newline at end of file + return flow_sources, flow_targets, flow_values, cost, alpha, beta, result_code + + +@cython.boundscheck(False) +@cython.wraparound(False) +def emd_c_lazy(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] coords_a, np.ndarray[double, ndim=2, mode="c"] coords_b, str metric='sqeuclidean', uint64_t max_iter=100000): + """Solves the Earth Movers distance problem with lazy cost computation from coordinates.""" + cdef int n1 = coords_a.shape[0] + cdef int n2 = coords_b.shape[0] + cdef int dim = coords_a.shape[1] + cdef int result_code = 0 + cdef double cost = 0 + cdef int metric_code + + # Validate dimension consistency + if coords_b.shape[1] != dim: + raise ValueError(f"Coordinate dimension mismatch: coords_a has {dim} dimensions but coords_b has {coords_b.shape[1]}") + + if metric == 'sqeuclidean': + metric_code = 0 + elif metric == 'euclidean': + metric_code = 1 + elif metric == 'cityblock': + metric_code = 2 + else: + raise ValueError(f"Unknown metric: {metric}") + cdef np.ndarray[double, ndim=1, mode="c"] alpha = np.zeros(n1) + cdef np.ndarray[double, ndim=1, mode="c"] beta = np.zeros(n2) + cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros([n1, n2]) + if not len(a): + a = np.ones((n1,)) / n1 + if not len(b): + b = np.ones((n2,)) / n2 + with nogil: + result_code = EMD_wrap_lazy(n1, n2, a.data, b.data, coords_a.data, coords_b.data, dim, metric_code, G.data, alpha.data, beta.data, &cost, max_iter) + return G, cost, alpha, beta, result_code diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 9612a8a24..1566e74e0 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -27,18 +27,10 @@ #pragma once #undef DEBUG_LVL -#define DEBUG_LVL 0 - -#if DEBUG_LVL>0 -#include -#endif - #undef EPSILON #undef _EPSILON -#undef MAX_DEBUG_ITER #define EPSILON 2.2204460492503131e-15 #define _EPSILON 1e-8 -#define MAX_DEBUG_ITER 100000 /// \ingroup min_cost_flow_algs @@ -238,7 +230,8 @@ namespace lemon { _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs), MAX(std::numeric_limits::max()), INF(std::numeric_limits::has_infinity ? - std::numeric_limits::infinity() : MAX) + std::numeric_limits::infinity() : MAX), + _lazy_cost(false), _coords_a(nullptr), _coords_b(nullptr), _dim(0), _metric(0), _n1(0), _n2(0) { // Reset data structures reset(); @@ -320,6 +313,8 @@ namespace lemon { // Data related to the underlying digraph const GR &_graph; int _node_num; + int _n1; // Number of source nodes (for lazy cost computation) + int _n2; // Number of target nodes (for lazy cost computation) ArcsType _arc_num; ArcsType _all_arc_num; ArcsType _search_arc_num; @@ -342,6 +337,12 @@ namespace lemon { //SparseValueVector _flow; CostVector _pi; + // Lazy cost computation support + bool _lazy_cost; + const double* _coords_a; + const double* _coords_b; + int _dim; + int _metric; // 0: sqeuclidean, 1: euclidean, 2: cityblock private: // Data for storing the spanning tree structure @@ -470,6 +471,41 @@ namespace lemon { _block_size = std::max(ArcsType(BLOCK_SIZE_FACTOR * std::sqrt(double(_search_arc_num))), MIN_BLOCK_SIZE); } + // Get cost for an arc (either from pre-computed array or compute lazily) + inline Cost getCost(ArcsType e) const { + if (!_ns._lazy_cost) { + return _cost[e]; + } else { + // For lazy mode, compute cost from coordinates inline + // _source and _target use reversed node numbering + int i = _ns._node_num - _source[e] - 1; + int j = _ns._n2 - _target[e] - 1; + + const double* __restrict__ xa = _ns._coords_a + i * _ns._dim; + const double* __restrict__ xb = _ns._coords_b + j * _ns._dim; + Cost cost = 0; + + if (_ns._metric == 0) { // sqeuclidean + for (int d = 0; d < _ns._dim; ++d) { + Cost diff = xa[d] - xb[d]; + cost += diff * diff; + } + return cost; + } else if (_ns._metric == 1) { // euclidean + for (int d = 0; d < _ns._dim; ++d) { + Cost diff = xa[d] - xb[d]; + cost += diff * diff; + } + return std::sqrt(cost); + } else { // cityblock + for (int d = 0; d < _ns._dim; ++d) { + cost += std::abs(xa[d] - xb[d]); + } + return cost; + } + } + } + // Find next entering arc bool findEnteringArc() { Cost c, min = 0; @@ -477,33 +513,33 @@ namespace lemon { ArcsType cnt = _block_size; double a; for (e = _next_arc; e != _search_arc_num; ++e) { - c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); + c = _state[e] * (getCost(e) + _pi[_source[e]] - _pi[_target[e]]); if (c < min) { min = c; _in_arc = e; } if (--cnt == 0) { a=fabs(_pi[_source[_in_arc]])>fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]):fabs(_pi[_target[_in_arc]]); - a=a>fabs(_cost[_in_arc])?a:fabs(_cost[_in_arc]); + a=a>fabs(getCost(_in_arc))?a:fabs(getCost(_in_arc)); if (min < -EPSILON*a) goto search_end; cnt = _block_size; } } for (e = 0; e != _next_arc; ++e) { - c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); + c = _state[e] * (getCost(e) + _pi[_source[e]] - _pi[_target[e]]); if (c < min) { min = c; _in_arc = e; } if (--cnt == 0) { a=fabs(_pi[_source[_in_arc]])>fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]):fabs(_pi[_target[_in_arc]]); - a=a>fabs(_cost[_in_arc])?a:fabs(_cost[_in_arc]); + a=a>fabs(getCost(_in_arc))?a:fabs(getCost(_in_arc)); if (min < -EPSILON*a) goto search_end; cnt = _block_size; } } a=fabs(_pi[_source[_in_arc]])>fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]):fabs(_pi[_target[_in_arc]]); - a=a>fabs(_cost[_in_arc])?a:fabs(_cost[_in_arc]); + a=a>fabs(getCost(_in_arc))?a:fabs(getCost(_in_arc)); if (min >= -EPSILON*a) return false; search_end: @@ -565,6 +601,90 @@ namespace lemon { return *this; } + /// \brief Enable lazy cost computation from coordinates. + /// + /// This function enables lazy cost computation where distances are + /// computed on-the-fly from point coordinates instead of using a + /// pre-computed cost matrix. + /// + /// \param coords_a Pointer to source coordinates (n1 x dim array) + /// \param coords_b Pointer to target coordinates (n2 x dim array) + /// \param dim Dimension of the coordinates + /// \param metric Distance metric: 0=sqeuclidean, 1=euclidean, 2=cityblock + /// + /// \return (*this) + NetworkSimplexSimple& setLazyCost(const double* coords_a, const double* coords_b, + int dim, int metric, int n1, int n2) { + _lazy_cost = true; + _coords_a = coords_a; + _coords_b = coords_b; + _dim = dim; + _metric = metric; + _n1 = n1; + _n2 = n2; + return *this; + } + + /// \brief Compute cost lazily from coordinates. + /// + /// Computes the distance between source node i and target node j + /// based on the specified metric. + /// + /// \param i Source node index + /// \param j Target node index (adjusted by subtracting n1) + /// + /// \return Cost (distance) between the two points + inline Cost computeLazyCost(int i, int j_adjusted) const { + const double* __restrict__ xa = _coords_a + i * _dim; + const double* __restrict__ xb = _coords_b + j_adjusted * _dim; + Cost cost = 0; + + if (_metric == 0) { // sqeuclidean + for (int d = 0; d < _dim; ++d) { + Cost diff = xa[d] - xb[d]; + cost += diff * diff; + } + return cost; + } else if (_metric == 1) { // euclidean + for (int d = 0; d < _dim; ++d) { + Cost diff = xa[d] - xb[d]; + cost += diff * diff; + } + return std::sqrt(cost); + } else { // cityblock (L1) + for (int d = 0; d < _dim; ++d) { + cost += std::abs(xa[d] - xb[d]); + } + return cost; + } + } + + + /// \brief Get cost for an arc (either from array or compute lazily). + /// + /// This is the main cost accessor that works from anywhere in the class. + /// In lazy mode, computes cost on-the-fly from coordinates. + /// In normal mode, returns pre-computed cost from array. + /// + /// \param arc_id The arc ID + /// \return Cost of the arc + inline Cost getCostForArc(ArcsType arc_id) const { + if (!_lazy_cost) { + return _cost[arc_id]; + } else { + // For artificial arcs (>= _arc_num), return 0 + // These are not real transport arcs + if (arc_id >= _arc_num) { + return 0; + } + // Compute lazily from coordinates + // _source and _target use reversed node numbering: _node_id(n) = _node_num - n - 1 + // Recover original indices: i = _node_num - _source[arc_id] - 1, j = _n2 - _target[arc_id] - 1 + int i = _node_num - _source[arc_id] - 1; + int j = _n2 - _target[arc_id] - 1; + return computeLazyCost(i, j); + } + } /// \brief Set the supply values of the nodes. /// @@ -689,14 +809,7 @@ namespace lemon { /// \see ProblemType, PivotRule /// \see resetParams(), reset() ProblemType run() { -#if DEBUG_LVL>0 - std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED << "\n" ; -#endif - if (!init()) return INFEASIBLE; -#if DEBUG_LVL>0 - std::cout << "Init done, starting iterations\n"; -#endif return start(); } @@ -879,8 +992,19 @@ namespace lemon { c += Number(it->second) * Number(_cost[it->first]); return c;*/ - for (ArcsType i=0; i<_flow.size(); i++) - c += _flow[i] * Number(_cost[i]); + if (!_lazy_cost) { + for (ArcsType i=0; i<_flow.size(); i++) + c += _flow[i] * Number(_cost[i]); + } else { + // Compute costs lazily + for (ArcsType i=0; i<_flow.size(); i++) { + if (_flow[i] != 0) { + int src = _node_num - _source[i] - 1; + int tgt = _n2 - _target[i] - 1; + c += _flow[i] * Number(computeLazyCost(src, tgt)); + } + } + } return c; } @@ -965,7 +1089,8 @@ namespace lemon { } else { ART_COST = 0; for (ArcsType i = 0; i != _arc_num; ++i) { - if (_cost[i] > ART_COST) ART_COST = _cost[i]; + Cost c = getCostForArc(i); + if (c > ART_COST) ART_COST = c; } ART_COST = (ART_COST + 1) * _node_num; } @@ -1305,8 +1430,8 @@ namespace lemon { // Update potentials void updatePotential() { Cost sigma = _forward[u_in] ? - _pi[v_in] - _pi[u_in] - _cost[_pred[u_in]] : - _pi[v_in] - _pi[u_in] + _cost[_pred[u_in]]; + _pi[v_in] - _pi[u_in] - getCostForArc(_pred[u_in]) : + _pi[v_in] - _pi[u_in] + getCostForArc(_pred[u_in]); // Update potentials in the subtree, which has been moved int end = _thread[_last_succ[u_in]]; for (int u = u_in; u != end; u = _thread[u]) { @@ -1365,7 +1490,7 @@ namespace lemon { Arc min_arc = INVALID; Arc a; _graph.firstIn(a, v); for (; a != INVALID; _graph.nextIn(a)) { - c = _cost[getArcID(a)]; + c = getCostForArc(getArcID(a)); if (c < min_cost) { min_cost = c; min_arc = a; @@ -1384,7 +1509,7 @@ namespace lemon { Arc min_arc = INVALID; Arc a; _graph.firstOut(a, u); for (; a != INVALID; _graph.nextOut(a)) { - c = _cost[getArcID(a)]; + c = getCostForArc(getArcID(a)); if (c < min_cost) { min_cost = c; min_arc = a; @@ -1400,7 +1525,7 @@ namespace lemon { for (ArcsType i = 0; i != arc_vector.size(); ++i) { in_arc = arc_vector[i]; // l'erreur est probablement ici... - if (_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] - + if (_state[in_arc] * (getCostForArc(in_arc) + _pi[_source[in_arc]] - _pi[_target[in_arc]]) >= 0) continue; findJoinNode(); bool change = findLeavingArc(); @@ -1436,27 +1561,6 @@ namespace lemon { retVal = MAX_ITER_REACHED; break; } -#if DEBUG_LVL>0 - if(iter_number>MAX_DEBUG_ITER) - break; - if(iter_number%1000==0||iter_number%1000==1){ - double curCost=totalCost(); - double sumFlow=0; - double a; - a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]); - a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]); - for (int64_t i=0; i<_flow.size(); i++) { - sumFlow+=_state[i]*_flow[i]; - } - std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; - std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; - std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; - std::cout << _cost[in_arc] << "\n"; - std::cout << _pi[_source[in_arc]] << "\n"; - std::cout << _pi[_target[in_arc]] << "\n"; - std::cout << a << "\n"; - } -#endif findJoinNode(); bool change = findLeavingArc(); @@ -1466,45 +1570,9 @@ namespace lemon { updateTreeStructure(); updatePotential(); } -#if DEBUG_LVL>0 - else{ - std::cout << "No change\n"; - } -#endif -#if DEBUG_LVL>1 - std::cout << "Arc in = (" << _source[in_arc] << ", " << _target[in_arc] << ")\n"; -#endif } - -#if DEBUG_LVL>0 - double curCost=totalCost(); - double sumFlow=0; - double a; - a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]); - a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]); - for (int64_t i=0; i<_flow.size(); i++) { - sumFlow+=_state[i]*_flow[i]; - } - - std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; - - std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; - std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; - -#endif - -#if DEBUG_LVL>1 - sumFlow=0; - for (int i=0; i<_flow.size(); i++) { - sumFlow+=_state[i]*_flow[i]; - if (_state[i]==STATE_TREE) { - std::cout << "Non zero value at (" << _node_num+1-_source[i] << ", " << _node_num+1-_target[i] << ")\n"; - } - } - std::cout << "Sum of the flow " << sumFlow << "\n"<< niter <<" iterations, current cost=" << totalCost() << "\n"; -#endif // Check feasibility if( retVal == OPTIMAL){ for (ArcsType e = _search_arc_num; e != _all_arc_num; ++e) { diff --git a/ot/solvers.py b/ot/solvers.py index 68b389d63..4e1238d78 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -1749,6 +1749,38 @@ def solve_sample( return res + elif ( + lazy + and method is None + and (reg is None or reg == 0) + and unbalanced is None + and X_a is not None + and X_b is not None + ): + # Use lazy EMD solver with coordinates (no regularization, balanced) + value_linear, log = emd2( + a, + b, + M=None, + X_a=X_a, + X_b=X_b, + metric=metric, + numItermax=max_iter if max_iter is not None else 100000, + log=True, + return_matrix=True, + numThreads=n_threads, + ) + + res = OTResult( + potentials=(log["u"], log["v"]), + value=value_linear, + value_linear=value_linear, + plan=log["G"], + status=log["warning"] if log["warning"] is not None else "Converged", + ) + + return res + else: # Detect backend nx = get_backend(X_a, X_b, a, b) diff --git a/test/test_solvers.py b/test/test_solvers.py index 040b38dc6..736cfa94d 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -53,8 +53,8 @@ "method": "nystroem", "metric": "euclidean", }, # fail nystroem on metric not euclidean - {"lazy": True}, # fail lazy for non regularized - {"lazy": True, "unbalanced": 1}, # fail lazy for non regularized unbalanced + # Note: {"lazy": True} now works - lazy EMD solver implemented + {"lazy": True, "unbalanced": 1}, # fail lazy for unbalanced (not supported) { "lazy": True, "reg": 1, @@ -601,6 +601,75 @@ def test_solve_sample_lazy(nx): np.testing.assert_allclose(sol0.plan, sol.lazy_plan[:], rtol=1e-5, atol=1e-5) +def test_solve_sample_lazy_emd(nx): + # test lazy EMD solver (no regularization, computes distances on-the-fly) + n_s = 20 + n_t = 25 + d = 2 + rng = np.random.RandomState(42) + + X_s = rng.rand(n_s, d) + X_t = rng.rand(n_t, d) + a = ot.utils.unif(n_s) + b = ot.utils.unif(n_t) + + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + + # Test all supported metrics + for metric in ["sqeuclidean", "euclidean", "cityblock"]: + # Standard solver: pre-compute distance matrix + M = ot.dist(X_sb, X_tb, metric=metric) + sol_standard = ot.solve(M, ab, bb) + + # Lazy solver: compute distances on-the-fly + sol_lazy = ot.solve_sample(X_sb, X_tb, ab, bb, lazy=True, metric=metric) + + # Check that results match + np.testing.assert_allclose( + nx.to_numpy(sol_standard.value), + nx.to_numpy(sol_lazy.value), + rtol=1e-10, + atol=1e-10, + err_msg=f"Lazy EMD cost mismatch for metric {metric}", + ) + + np.testing.assert_allclose( + nx.to_numpy(sol_standard.plan), + nx.to_numpy(sol_lazy.plan), + rtol=1e-10, + atol=1e-10, + err_msg=f"Lazy EMD plan mismatch for metric {metric}", + ) + + # Test larger problem to verify memory savings benefit + n_large = 100 + X_s_large = rng.rand(n_large, d) + X_t_large = rng.rand(n_large, d) + a_large = ot.utils.unif(n_large) + b_large = ot.utils.unif(n_large) + + X_sb_large, X_tb_large, ab_large, bb_large = nx.from_numpy( + X_s_large, X_t_large, a_large, b_large + ) + + # Standard solver + M_large = ot.dist(X_sb_large, X_tb_large, metric="sqeuclidean") + sol_standard_large = ot.solve(M_large, ab_large, bb_large) + + # Lazy solver (avoids storing 100x100 cost matrix) + sol_lazy_large = ot.solve_sample( + X_sb_large, X_tb_large, ab_large, bb_large, lazy=True, metric="sqeuclidean" + ) + + np.testing.assert_allclose( + nx.to_numpy(sol_standard_large.value), + nx.to_numpy(sol_lazy_large.value), + rtol=1e-9, + atol=1e-9, + err_msg="Lazy EMD cost mismatch for large problem", + ) + + @pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") @pytest.mark.skipif(not geomloss, reason="pytorch not installed") @pytest.skip_backend("tf") From 22ab1ff2e66657a3e447e773651f59a99247be38 Mon Sep 17 00:00:00 2001 From: nathanneike Date: Tue, 20 Jan 2026 10:24:03 +0100 Subject: [PATCH 2/8] Add emd2_lazy function and fix SciPy sparse matrix compatibility --- ot/__init__.py | 2 + ot/gromov/_estimators.py | 25 ++-- ot/lp/__init__.py | 3 +- ot/lp/_network_simplex.py | 255 +++++++++++++++++++++++++++++--------- ot/solvers.py | 10 +- test/test_solvers.py | 54 ++++---- 6 files changed, 246 insertions(+), 103 deletions(-) diff --git a/ot/__init__.py b/ot/__init__.py index 8a389bb98..26f428aa1 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -41,6 +41,7 @@ from .lp import ( emd, emd2, + emd2_lazy, emd_1d, emd2_1d, wasserstein_1d, @@ -82,6 +83,7 @@ __all__ = [ "emd", "emd2", + "emd2_lazy", "emd_1d", "sinkhorn", "sinkhorn2", diff --git a/ot/gromov/_estimators.py b/ot/gromov/_estimators.py index 14871bfe3..18359afe5 100644 --- a/ot/gromov/_estimators.py +++ b/ot/gromov/_estimators.py @@ -122,8 +122,8 @@ def GW_distance_estimation( for i in range(nb_samples_p): if nx.issparse(T): - T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,)) - T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,)) + T_indexi = nx.reshape(nx.todense(T[[index_i[i]], :]), (-1,)) + T_indexj = nx.reshape(nx.todense(T[[index_j[i]], :]), (-1,)) else: T_indexi = T[index_i[i], :] T_indexj = T[index_j[i], :] @@ -243,16 +243,18 @@ def pointwise_gromov_wasserstein( index = np.zeros(2, dtype=int) # Initialize with default marginal - index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p)) - index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q)) + index[0] = int(generator.choice(len_p, size=1, p=nx.to_numpy(p)).item()) + index[1] = int(generator.choice(len_q, size=1, p=nx.to_numpy(q)).item()) T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)) best_gw_dist_estimated = np.inf for cpt in range(max_iter): - index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p)) - T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,)) - index[1] = generator.choice( - len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0)) + index[0] = int(generator.choice(len_p, size=1, p=nx.to_numpy(p)).item()) + T_index0 = nx.reshape(nx.todense(T[[index[0]], :]), (-1,)) + index[1] = int( + generator.choice( + len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0)) + ).item() ) if alpha == 1: @@ -404,10 +406,15 @@ def sampled_gromov_wasserstein( ) Lik = 0 for i, index0_i in enumerate(index0): + T_row = ( + nx.reshape(nx.todense(T[[index0_i], :]), (-1,)) + if nx.issparse(T) + else T[index0_i, :] + ) index1 = generator.choice( len_q, size=nb_samples_grad_q, - p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])), + p=nx.to_numpy(T_row / nx.sum(T_row)), replace=False, ) # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly. diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index f8924a322..8e88d63c8 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -9,7 +9,7 @@ # License: MIT License from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize -from ._network_simplex import emd, emd2 +from ._network_simplex import emd, emd2, emd2_lazy from ._barycenter_solvers import ( barycenter, free_support_barycenter, @@ -35,6 +35,7 @@ __all__ = [ "emd", "emd2", + "emd2_lazy", "barycenter", "free_support_barycenter", "cvx", diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py index 9a4969a89..7cce0113c 100644 --- a/ot/lp/_network_simplex.py +++ b/ot/lp/_network_simplex.py @@ -440,10 +440,7 @@ def emd( def emd2( a, b, - M=None, - X_a=None, - X_b=None, - metric="sqeuclidean", + M, processes=1, numItermax=100000, log=False, @@ -474,7 +471,7 @@ def emd2( .. note:: This function will cast the computed transport plan and transportation loss to the data type of the provided input with the - following priority: :math:`\mathbf{a}`, then :math:`\mathbf{b}`, + following priority : :math:`\mathbf{a}`, then :math:`\mathbf{b}`, then :math:`\mathbf{M}` if marginals are not provided. Casting to an integer tensor might result in a loss of precision. If this behaviour is unwanted, please make sure to provide a @@ -490,23 +487,13 @@ def emd2( Source histogram (uniform weight if empty list) b : (nt,) array-like, float64 Target histogram (uniform weight if empty list) - M : (ns,nt) array-like or sparse matrix, float64, optional + M : (ns,nt) array-like or sparse matrix, float64 Loss matrix. Can be: - Dense array (c-order array in numpy with type float64) - Sparse matrix in backend's format (scipy.sparse.coo_matrix for NumPy backend, torch.sparse_coo_tensor for PyTorch backend, etc.) - Either M or (X_a, X_b) must be provided. - X_a : (ns, dim) array-like, float64, optional - Source coordinates for lazy cost computation. - If provided along with X_b, costs will be computed on-the-fly. - X_b : (nt, dim) array-like, float64, optional - Target coordinates for lazy cost computation. - If provided along with X_a, costs will be computed on-the-fly. - metric : str, optional (default='sqeuclidean') - Distance metric for lazy mode. Options: 'sqeuclidean', 'euclidean', 'cityblock'. - Only used when X_a and X_b are provided. processes : int, optional (default=1) Nb of processes used for multiple emd computation (deprecated) numItermax : int, optional (default=100000) @@ -577,25 +564,11 @@ def emd2( n1, n2 = None, None - # Check if we're using lazy mode with coordinates - use_lazy = X_a is not None and X_b is not None and M is None + a, b, M = list_to_array(a, b, M) + nx = get_backend(a, b, M) - if use_lazy: - # Lazy mode: coordinates provided - a, b, X_a, X_b = list_to_array(a, b, X_a, X_b) - nx = get_backend(a, b, X_a, X_b) - n1, n2 = X_a.shape[0], X_b.shape[0] - is_sparse = False - else: - # Standard mode: cost matrix provided - if M is None: - raise ValueError( - "Either M (cost matrix) or (X_a, X_b) coordinates must be provided" - ) - a, b, M = list_to_array(a, b, M) - nx = get_backend(a, b, M) - # Check if M is sparse using backend's issparse method - is_sparse = nx.issparse(M) + # Check if M is sparse using backend's issparse method + is_sparse = nx.issparse(M) # Save original sparse tensor for gradient tracking (before conversion to numpy) M_original_sparse = None @@ -622,9 +595,6 @@ def emd2( type_as = a elif b is not None and len(b) != 0: type_as = b - elif use_lazy: - # In lazy mode, use coordinates for type inference - type_as = X_a else: type_as = a if a is not None else b @@ -680,10 +650,7 @@ def emd2( def f(b): bsel = b != 0 - if use_lazy: - # Solve with lazy cost computation from coordinates - G, cost, u, v, result_code = emd_c_lazy(a, b, X_a, X_b, metric, numItermax) - elif is_sparse: + if is_sparse: # Solve sparse EMD flow_sources, flow_targets, flow_values, cost, u, v, result_code = ( emd_c_sparse(a, b, edge_sources, edge_targets, edge_costs, numItermax) @@ -757,27 +724,6 @@ def f(b): shape=(n1, n2), type_as=type_as, ) - elif use_lazy: - # Lazy case: warn about integer casting - if not nx.is_floating_point(type_as): - warnings.warn( - "Input histogram consists of integer. The transport plan will be " - "casted accordingly, possibly resulting in a loss of precision. " - "If this behaviour is unwanted, please make sure your input " - "histogram consists of floating point elements.", - stacklevel=2, - ) - - G_backend = nx.from_numpy(G, type_as=type_as) - # For now, just set gradients wrt marginals - cost = nx.set_gradients( - nx.from_numpy(cost, type_as=type_as), - (a0, b0), - ( - nx.from_numpy(u - np.mean(u), type_as=type_as), - nx.from_numpy(v - np.mean(v), type_as=type_as), - ), - ) else: # Dense case: warn about integer casting if not nx.is_floating_point(type_as): @@ -829,3 +775,188 @@ def f(b): res = list(map(f, [b[:, i].copy() for i in range(nb)])) return res + + +def emd2_lazy( + a, + b, + X_a, + X_b, + metric="sqeuclidean", + numItermax=100000, + log=False, + return_matrix=False, + center_dual=True, + check_marginals=True, +): + r"""Solves the Earth Movers distance problem with lazy cost computation and returns the loss + + .. math:: + \min_\gamma \quad \langle \gamma, \mathbf{M}(\mathbf{X}_a, \mathbf{X}_b) \rangle_F + + s.t. \ \gamma \mathbf{1} = \mathbf{a} + + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 + + where : + + - :math:`\mathbf{M}(\mathbf{X}_a, \mathbf{X}_b)` is computed on-the-fly from coordinates + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + + .. note:: This function computes distances on-the-fly during the network simplex algorithm, + avoiding the O(ns*nt) memory cost of pre-computing the full cost matrix. Memory usage + is O(ns+nt) instead. + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + Parameters + ---------- + a : (ns,) array-like, float64 + Source histogram (uniform weight if empty list) + b : (nt,) array-like, float64 + Target histogram (uniform weight if empty list) + X_a : (ns, dim) array-like, float64 + Source sample coordinates + X_b : (nt, dim) array-like, float64 + Target sample coordinates + metric : str, optional (default='sqeuclidean') + Distance metric for cost computation. Options: + + - 'sqeuclidean': Squared Euclidean distance + - 'euclidean': Euclidean distance + - 'cityblock': Manhattan/L1 distance + + numItermax : int, optional (default=100000) + Maximum number of iterations before stopping if not converged + log: boolean, optional (default=False) + If True, returns a dictionary containing the cost, dual variables, + and optionally the transport plan (sparse format) + return_matrix: boolean, optional (default=False) + If True, returns the optimal transportation matrix in the log (sparse format) + center_dual: boolean, optional (default=True) + If True, centers the dual potential using :py:func:`ot.lp.center_ot_dual` + check_marginals: bool, optional (default=True) + If True, checks that the marginals mass are equal + + Returns + ------- + W: float + Optimal transportation loss + log: dict + If input log is True, a dictionary containing: + + - cost: the optimal transportation cost + - u, v: dual variables + - warning: solver status message + - result_code: solver return code + - G: (optional) sparse transport plan if return_matrix=True + + See Also + -------- + ot.emd2 : EMD with pre-computed cost matrix + ot.lp.emd_c_lazy : Low-level C++ lazy solver + """ + + a, b, X_a, X_b = list_to_array(a, b, X_a, X_b) + nx = get_backend(a, b, X_a, X_b) + + n1, n2 = X_a.shape[0], X_b.shape[0] + + # Validate dimensions match + if X_a.shape[1] != X_b.shape[1]: + raise ValueError( + f"X_a and X_b must have the same number of dimensions, " + f"got {X_a.shape[1]} and {X_b.shape[1]}" + ) + + if a is not None and len(a) != 0: + type_as = a + elif b is not None and len(b) != 0: + type_as = b + else: + type_as = X_a + + # if empty array given then use uniform distributions + if a is None or len(a) == 0: + a = nx.ones((n1,), type_as=type_as) / n1 + if b is None or len(b) == 0: + b = nx.ones((n2,), type_as=type_as) / n2 + + a0, b0 = a, b + + # Convert to numpy for C++ backend + X_a_np = nx.to_numpy(X_a) + X_b_np = nx.to_numpy(X_b) + a_np = nx.to_numpy(a) + b_np = nx.to_numpy(b) + + X_a_np = np.asarray(X_a_np, dtype=np.float64, order="C") + X_b_np = np.asarray(X_b_np, dtype=np.float64, order="C") + a_np = np.asarray(a_np, dtype=np.float64) + b_np = np.asarray(b_np, dtype=np.float64) + + assert ( + a_np.shape[0] == n1 and b_np.shape[0] == n2 + ), "Dimension mismatch, check dimensions of X_a/X_b with a and b" + + # ensure that same mass + if check_marginals: + np.testing.assert_almost_equal( + a_np.sum(0), + b_np.sum(0, keepdims=True), + err_msg="a and b vector must have the same sum", + decimal=6, + ) + b_np = b_np * a_np.sum(0) / b_np.sum(0, keepdims=True) + + # Solve with lazy cost computation + G, cost, u, v, result_code = emd_c_lazy( + a_np, b_np, X_a_np, X_b_np, metric, numItermax + ) + + # Center dual potentials + if center_dual: + u, v = center_ot_dual(u, v, a_np, b_np) + + # Convert sparse plan to backend format + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2, + ) + + G_backend = nx.from_numpy(G, type_as=type_as) + + # Set gradients wrt marginals + cost_backend = nx.set_gradients( + nx.from_numpy(cost, type_as=type_as), + (a0, b0), + ( + nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), + ), + ) + + check_result(result_code) + + # Return results + if log or return_matrix: + log_dict = { + "cost": cost_backend, + "u": nx.from_numpy(u, type_as=type_as), + "v": nx.from_numpy(v, type_as=type_as), + "warning": check_result(result_code), + "result_code": result_code, + } + if return_matrix: + log_dict["G"] = G_backend + return cost_backend, log_dict + else: + return cost_backend diff --git a/ot/solvers.py b/ot/solvers.py index 4e1238d78..c85f2dd76 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -8,7 +8,7 @@ # License: MIT License from .utils import OTResult, dist -from .lp import emd2, wasserstein_1d +from .lp import emd2, emd2_lazy, wasserstein_1d from .backend import get_backend from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced from .bregman import ( @@ -1758,17 +1758,15 @@ def solve_sample( and X_b is not None ): # Use lazy EMD solver with coordinates (no regularization, balanced) - value_linear, log = emd2( + value_linear, log = emd2_lazy( a, b, - M=None, - X_a=X_a, - X_b=X_b, + X_a, + X_b, metric=metric, numItermax=max_iter if max_iter is not None else 100000, log=True, return_matrix=True, - numThreads=n_threads, ) res = OTResult( diff --git a/test/test_solvers.py b/test/test_solvers.py index 736cfa94d..548ce4a05 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -601,7 +601,8 @@ def test_solve_sample_lazy(nx): np.testing.assert_allclose(sol0.plan, sol.lazy_plan[:], rtol=1e-5, atol=1e-5) -def test_solve_sample_lazy_emd(nx): +@pytest.mark.parametrize("metric", ["sqeuclidean", "euclidean", "cityblock"]) +def test_solve_sample_lazy_emd(nx, metric): # test lazy EMD solver (no regularization, computes distances on-the-fly) n_s = 20 n_t = 25 @@ -615,34 +616,37 @@ def test_solve_sample_lazy_emd(nx): X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) - # Test all supported metrics - for metric in ["sqeuclidean", "euclidean", "cityblock"]: - # Standard solver: pre-compute distance matrix - M = ot.dist(X_sb, X_tb, metric=metric) - sol_standard = ot.solve(M, ab, bb) - - # Lazy solver: compute distances on-the-fly - sol_lazy = ot.solve_sample(X_sb, X_tb, ab, bb, lazy=True, metric=metric) - - # Check that results match - np.testing.assert_allclose( - nx.to_numpy(sol_standard.value), - nx.to_numpy(sol_lazy.value), - rtol=1e-10, - atol=1e-10, - err_msg=f"Lazy EMD cost mismatch for metric {metric}", - ) + # Standard solver: pre-compute distance matrix + M = ot.dist(X_sb, X_tb, metric=metric) + sol_standard = ot.solve(M, ab, bb) + + # Lazy solver: compute distances on-the-fly + sol_lazy = ot.solve_sample(X_sb, X_tb, ab, bb, lazy=True, metric=metric) + + # Check that results match + np.testing.assert_allclose( + nx.to_numpy(sol_standard.value), + nx.to_numpy(sol_lazy.value), + rtol=1e-10, + atol=1e-10, + err_msg=f"Lazy EMD cost mismatch for metric {metric}", + ) + + np.testing.assert_allclose( + nx.to_numpy(sol_standard.plan), + nx.to_numpy(sol_lazy.plan), + rtol=1e-10, + atol=1e-10, + err_msg=f"Lazy EMD plan mismatch for metric {metric}", + ) - np.testing.assert_allclose( - nx.to_numpy(sol_standard.plan), - nx.to_numpy(sol_lazy.plan), - rtol=1e-10, - atol=1e-10, - err_msg=f"Lazy EMD plan mismatch for metric {metric}", - ) +def test_solve_sample_lazy_emd_large(nx): # Test larger problem to verify memory savings benefit n_large = 100 + d = 2 + rng = np.random.RandomState(42) + X_s_large = rng.rand(n_large, d) X_t_large = rng.rand(n_large, d) a_large = ot.utils.unif(n_large) From 3a67a4f08ecc62cedae1be37da1ca6c0c135c5d5 Mon Sep 17 00:00:00 2001 From: nathanneike Date: Tue, 20 Jan 2026 10:47:42 +0100 Subject: [PATCH 3/8] small fix errors not appearing locally --- ot/regpath.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ot/regpath.py b/ot/regpath.py index e64ca7c77..aedc35b88 100644 --- a/ot/regpath.py +++ b/ot/regpath.py @@ -486,6 +486,9 @@ def complement_schur(M_current, b, d, id_pop): else: X = M_current.dot(b) s = d - b.T.dot(X) + # Ensure s is a scalar (extract from array if needed) + if np.ndim(s) > 0: + s = s.item() M = np.zeros((n, n)) M[:-1, :-1] = M_current + X.dot(X.T) / s X_ravel = X.ravel() From cd610707dda74be3e0d8be883bac19cdfef84435 Mon Sep 17 00:00:00 2001 From: nathanneike Date: Tue, 20 Jan 2026 10:58:50 +0100 Subject: [PATCH 4/8] Fix SciPy version compatibility for distance metrics --- ot/utils.py | 4 +++- test/test_utils.py | 25 ++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/ot/utils.py b/ot/utils.py index cc9de4f02..64bf1ace9 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -429,7 +429,9 @@ def dist( else: if isinstance(metric, str) and metric.endswith("minkowski"): return cdist(x1, x2, metric=metric, p=p, w=w) - if w is not None: + # Only pass w parameter for metrics that support it + # According to SciPy docs, only 'minkowski' and 'wminkowski' support w + if w is not None and metric in ["minkowski", "wminkowski"]: return cdist(x1, x2, metric=metric, w=w) return cdist(x1, x2, metric=metric) diff --git a/test/test_utils.py b/test/test_utils.py index 0b2769109..8c5e65b93 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -19,7 +19,7 @@ "correlation", ] -lst_all_metrics = lst_metrics + [ +lst_all_metrics_candidates = lst_metrics + [ "braycurtis", "canberra", "chebyshev", @@ -34,6 +34,18 @@ "yule", ] +# Filter to only include metrics available in current SciPy version +# (some metrics like sokalmichener were removed in newer SciPy versions) +lst_all_metrics = [] +for metric in lst_all_metrics_candidates: + try: + scipy.spatial.distance.cdist( + np.array([[0, 0]]), np.array([[1, 1]]), metric=metric + ) + lst_all_metrics.append(metric) + except ValueError: + pass + def get_LazyTensor(nx): n1 = 100 @@ -240,7 +252,18 @@ def test_dist(): "seuclidean", ] # do not support weights depending on scipy's version + # Filter out metrics not available in current scipy version + from scipy.spatial.distance import cdist + + available_metrics_w = [] for metric in metrics_w: + try: + cdist(x[:2], x[:2], metric=metric) + available_metrics_w.append(metric) + except ValueError: + pass + + for metric in available_metrics_w: print(metric) ot.dist(x, x, metric=metric, p=3, w=rng.random((2,))) ot.dist( From 0297df11edc20e480a479f2e44d8d40e13ebe0ac Mon Sep 17 00:00:00 2001 From: nathanneike Date: Thu, 22 Jan 2026 10:47:48 +0100 Subject: [PATCH 5/8] fixed issues added Release info --- RELEASES.md | 10 +++++++--- ot/lp/EMD_wrapper.cpp | 18 ++++++++---------- ot/lp/_network_simplex.py | 28 ++++++++++------------------ ot/lp/emd_wrap.pyx | 19 +++++++++++-------- ot/lp/network_simplex_simple.h | 22 ++++++++++++++++++---- ot/mapping.py | 8 ++++++-- ot/solvers.py | 4 ++-- test/test_solvers.py | 1 - 8 files changed, 62 insertions(+), 48 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index cbe6c84fe..e9e569e26 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -2,15 +2,19 @@ ## 0.9.7.dev0 -This new release adds support for sparse cost matrices in the exact EMD solver. Users can now pass sparse cost matrices (e.g., k-NN graphs, sparse graphs) and receive sparse transport plans, significantly reducing memory footprint for large-scale problems. The implementation is backend-agnostic, automatically handling scipy.sparse for NumPy and torch.sparse for PyTorch, and preserves full gradient computation capabilities for automatic differentiation in PyTorch. This enables efficient solving of OT problems on graphs with millions of nodes where only a sparse subset of edges have finite costs. +This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation. #### New features +- Add lazy EMD solver with on-the-fly distance computation from coordinates (PR #788) +- Optimize EMD solver with bulk cost array copy (PR #788) - Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782) -- Geomloss function now handles both scalar and slice indices for i and j. Using backend agnostic reshaping. Allows to do plan[i,:] and plan[:,j] (PR #785) +- Geomloss function now handles both scalar and slice indices for i and j (PR #785) - Add support for sparse cost matrices in EMD solver (PR #778, Issue #397) #### Closed issues -- Fix O(n³) performance bottleneck in sparse bipartite graph arc iteration (PR #785) +- Fix NumPy 2.x compatibility in Brenier potential bounds (PR #788) +- Fix MSVC Windows build by removing __restrict__ keyword (PR #788) +- Fix O(n³) performance bottleneck in sparse bipartite graph arc iteration (PR #785) - Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770) - Add test for build from source (PR #772, Issue #764) - Fix device for batch Ot solver in `ot.batch` (PR #784, Issue #783) diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 6aa27897a..c458d0f49 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -83,15 +83,15 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, net.supplyMap(&weights1[0], (int) n, &weights2[0], (int) m); - // Set the cost of each edge + // Set the cost of each edge using bulk copy for efficiency + std::vector cost_array(n * m); int64_t idarc = 0; for (uint64_t i=0; i cost_array(n * m); int64_t idarc = 0; for (uint64_t i=0; i(*this) + NetworkSimplexSimple& setCostArray(const Cost* cost_array) { + std::copy(cost_array, cost_array + _arc_num, _cost.begin()); + return *this; + } + /// \brief Enable lazy cost computation from coordinates. /// /// This function enables lazy cost computation where distances are @@ -635,8 +649,8 @@ namespace lemon { /// /// \return Cost (distance) between the two points inline Cost computeLazyCost(int i, int j_adjusted) const { - const double* __restrict__ xa = _coords_a + i * _dim; - const double* __restrict__ xb = _coords_b + j_adjusted * _dim; + const double* xa = _coords_a + i * _dim; + const double* xb = _coords_b + j_adjusted * _dim; Cost cost = 0; if (_metric == 0) { // sqeuclidean diff --git a/ot/mapping.py b/ot/mapping.py index cc3e6cd57..bb1a78e8a 100644 --- a/ot/mapping.py +++ b/ot/mapping.py @@ -376,7 +376,9 @@ def nearest_brenier_potential_predict_bounds( ] problem = cvx.Problem(objective, constraints) problem.solve(solver=solver) - phi_lu[0, y_idx] = phi_l_y.value + phi_lu[0, y_idx] = ( + phi_l_y.value if np.isscalar(phi_l_y.value) else phi_l_y.value.item() + ) G_lu[0, y_idx] = G_l_y.value if log: log_item["l"] = { @@ -403,7 +405,9 @@ def nearest_brenier_potential_predict_bounds( ] problem = cvx.Problem(objective, constraints) problem.solve(solver=solver) - phi_lu[1, y_idx] = phi_u_y.value + phi_lu[1, y_idx] = ( + phi_u_y.value if np.isscalar(phi_u_y.value) else phi_u_y.value.item() + ) G_lu[1, y_idx] = G_u_y.value if log: log_item["u"] = { diff --git a/ot/solvers.py b/ot/solvers.py index c85f2dd76..99314da02 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -1759,10 +1759,10 @@ def solve_sample( ): # Use lazy EMD solver with coordinates (no regularization, balanced) value_linear, log = emd2_lazy( - a, - b, X_a, X_b, + a, + b, metric=metric, numItermax=max_iter if max_iter is not None else 100000, log=True, diff --git a/test/test_solvers.py b/test/test_solvers.py index 548ce4a05..9b0ababbf 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -53,7 +53,6 @@ "method": "nystroem", "metric": "euclidean", }, # fail nystroem on metric not euclidean - # Note: {"lazy": True} now works - lazy EMD solver implemented {"lazy": True, "unbalanced": 1}, # fail lazy for unbalanced (not supported) { "lazy": True, From 1b14aa2406e4eac90e3f2064d80a5c294a45eab6 Mon Sep 17 00:00:00 2001 From: nathanneike Date: Thu, 22 Jan 2026 10:52:16 +0100 Subject: [PATCH 6/8] Modified OpenMP implementation to allow build --- ot/lp/network_simplex_simple_omp.h | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/ot/lp/network_simplex_simple_omp.h b/ot/lp/network_simplex_simple_omp.h index 890b7ab03..5222372b1 100644 --- a/ot/lp/network_simplex_simple_omp.h +++ b/ot/lp/network_simplex_simple_omp.h @@ -703,6 +703,20 @@ namespace lemon_omp { return *this; } + /// \brief Set all cost values at once from an array. + /// + /// This function sets all arc costs from a vector in one bulk operation. + /// More efficient than calling setCost() for each arc individually. + /// + /// \param cost_array Vector containing all arc costs in order + /// + /// \return (*this) + template + NetworkSimplexSimple& setCostArray(const std::vector& cost_array) { + std::copy(cost_array.begin(), cost_array.end(), _cost.begin()); + return *this; + } + /// \brief Set the supply values of the nodes. /// From 069508a2b22b511095a7851f47fb9daf5a3ec13e Mon Sep 17 00:00:00 2001 From: nathanneike Date: Thu, 22 Jan 2026 11:44:27 +0100 Subject: [PATCH 7/8] Removed set cost array with bulk --- ot/lp/EMD_wrapper.cpp | 18 ++++++++++-------- ot/lp/network_simplex_simple.h | 14 -------------- ot/lp/network_simplex_simple_omp.h | 15 --------------- 3 files changed, 10 insertions(+), 37 deletions(-) diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index c458d0f49..6aa27897a 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -83,15 +83,15 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, net.supplyMap(&weights1[0], (int) n, &weights2[0], (int) m); - // Set the cost of each edge using bulk copy for efficiency - std::vector cost_array(n * m); + // Set the cost of each edge int64_t idarc = 0; for (uint64_t i=0; i cost_array(n * m); + // Set the cost of each edge int64_t idarc = 0; for (uint64_t i=0; i(*this) - NetworkSimplexSimple& setCostArray(const Cost* cost_array) { - std::copy(cost_array, cost_array + _arc_num, _cost.begin()); - return *this; - } - /// \brief Enable lazy cost computation from coordinates. /// /// This function enables lazy cost computation where distances are diff --git a/ot/lp/network_simplex_simple_omp.h b/ot/lp/network_simplex_simple_omp.h index 5222372b1..d8ef672e3 100644 --- a/ot/lp/network_simplex_simple_omp.h +++ b/ot/lp/network_simplex_simple_omp.h @@ -703,21 +703,6 @@ namespace lemon_omp { return *this; } - /// \brief Set all cost values at once from an array. - /// - /// This function sets all arc costs from a vector in one bulk operation. - /// More efficient than calling setCost() for each arc individually. - /// - /// \param cost_array Vector containing all arc costs in order - /// - /// \return (*this) - template - NetworkSimplexSimple& setCostArray(const std::vector& cost_array) { - std::copy(cost_array.begin(), cost_array.end(), _cost.begin()); - return *this; - } - - /// \brief Set the supply values of the nodes. /// /// This function sets the supply values of the nodes. From 58c520c9ba9f5a34e072ad68ea90e6cef92da331 Mon Sep 17 00:00:00 2001 From: nathanneike Date: Thu, 22 Jan 2026 13:34:51 +0100 Subject: [PATCH 8/8] Updated release --- RELEASES.md | 1 - 1 file changed, 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index e9e569e26..123b11f8d 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,7 +6,6 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver #### New features - Add lazy EMD solver with on-the-fly distance computation from coordinates (PR #788) -- Optimize EMD solver with bulk cost array copy (PR #788) - Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782) - Geomloss function now handles both scalar and slice indices for i and j (PR #785) - Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)