diff --git a/ot/__init__.py b/ot/__init__.py
index f16b6fcfc..9c33e9feb 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -50,7 +50,7 @@
                      gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
 from .weak import weak_optimal_transport
 from .factored import factored_optimal_transport
-from .solvers import solve, solve_gromov
+from .solvers import solve, solve_gromov, solve_sample
 
 # utils functions
 from .utils import dist, unif, tic, toc, toq
@@ -65,7 +65,7 @@
            'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere',
            'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein',
            'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport',
-           'factored_optimal_transport', 'solve', 'solve_gromov',
+           'factored_optimal_transport', 'solve', 'solve_gromov', 'solve_sample',
            'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
            'binary_search_circle', 'wasserstein_circle',
            'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif']
diff --git a/ot/solvers.py b/ot/solvers.py
index 0313cf588..4b261b5cc 100644
--- a/ot/solvers.py
+++ b/ot/solvers.py
@@ -7,11 +7,11 @@
 #
 # License: MIT License
 
-from .utils import OTResult
+from .utils import OTResult, unif, dist
 from .lp import emd2
 from .backend import get_backend
 from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced
 from .bregman import sinkhorn_log
 from .partial import partial_wasserstein_lagrange
 from .smooth import smooth_ot_dual
 from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2,
@@ -848,3 +848,129 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None,
                    value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx)
 
     return res
+
+
+def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None,
+                 unbalanced_type='KL', is_Lazy=False, batch_size=None, n_threads=1, max_iter=None, plan_init=None,
+                 potentials_init=None, tol=None, verbose=False):
+    r"""Solve the discrete optimal transport problem using the samples in the source and target domains.
+
+    The function solves the following general optimal transport problem
+
+    .. math::
+        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) +
+        \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
+        \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})
+
+    where the cost matrix :math:`\mathbf{M}` is computed from the samples
+    `X_s` and `X_t` with `metric`.
+
+    The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By
+    default ``reg=None`` and there is no regularization. The unbalanced marginal
+    penalization can be selected with `unbalanced` (:math:`\lambda_u`) and
+    `unbalanced_type`. By default ``unbalanced=None`` and the function
+    solves the exact optimal transport problem (respecting the marginals).
+
+    Parameters
+    ----------
+    X_s : array-like, shape (n_samples_a, dim)
+        samples in the source domain
+    X_t : array-like, shape (n_samples_b, dim)
+        samples in the target domain
+    a : array-like, shape (n_samples_a,), optional
+        Samples weights in the source domain (default is uniform)
+    b : array-like, shape (n_samples_b,), optional
+        Samples weights in the target domain (default is uniform)
+    metric : str, optional
+        Metric used to compute the cost matrix between the samples, by
+        default 'sqeuclidean' (only value currently supported)
+    reg : float, optional
+        Regularization weight :math:`\lambda_r`, by default None (no reg., exact
+        OT)
+    reg_type : str, optional
+        Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL"
+    unbalanced : float, optional
+        Unbalanced penalization weight :math:`\lambda_u`, by default None
+        (balanced OT)
+    unbalanced_type : str, optional
+        Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL"
+    is_Lazy : bool, optional
+        Request the lazy solvers (meant to reduce the memory cost), by default
+        False. Work in progress: ``is_Lazy=True`` currently always raises
+        NotImplementedError
+    batch_size : int, optional
+        Batch size reserved for the lazy solvers, currently unused, by default None
+    n_threads : int, optional
+        Number of OMP threads for exact OT solver, by default 1
+    max_iter : int, optional
+        Maximum number of iteration, by default None (default values in each solvers)
+    plan_init : array_like, shape (n_samples_a, n_samples_b), optional
+        Initialization of the OT plan for iterative methods, by default None
+    potentials_init : (array_like(n_samples_a,),array_like(n_samples_b,)), optional
+        Initialization of the OT dual potentials for iterative methods, by default None
+    tol : float, optional
+        Tolerance for solution precision, by default None (default values in each solvers)
+    verbose : bool, optional
+        Print information in the solver, by default False
+
+    Returns
+    -------
+
+    res : OTResult()
+        Result of the optimization problem. The information can be obtained as follows:
+
+        - res.plan : OT plan :math:`\mathbf{T}`
+        - res.potentials : OT dual potentials
+        - res.value : Optimal value of the optimization problem
+        - res.value_linear : Linear OT loss with the optimal OT plan
+
+        See :any:`OTResult` for more information.
+
+    Raises
+    ------
+    NotImplementedError
+        If `metric` is not 'sqeuclidean' or if ``is_Lazy=True``.
+
+    """
+
+    # Detect backend
+    arr = [X_s, X_t]
+    if a is not None:
+        arr.append(a)
+    if b is not None:
+        arr.append(b)
+    nx = get_backend(*arr)
+
+    # create uniform weights if not given
+    ns, nt = X_s.shape[0], X_t.shape[0]
+    if a is None:
+        a = nx.from_numpy(unif(ns), type_as=X_s)
+    if b is None:
+        b = nx.from_numpy(unif(nt), type_as=X_s)
+
+    if metric != 'sqeuclidean':
+        raise (NotImplementedError('Not implemented metric = {} (only sqeuclidean)'.format(metric)))
+
+    if is_Lazy:
+        # The lazy solvers are work in progress: fail fast, before running any
+        # solver whose result could not be returned anyway.
+        # TODO: wrap a batched Sinkhorn solver (using batch_size) for the
+        # balanced regularized case and return its potentials and lazy plan.
+        no_reg = reg is None or reg == 0
+        if unbalanced is None:
+            if no_reg:
+                raise (NotImplementedError('Not implemented balanced with no regularization'))
+            raise (NotImplementedError('Not implemented balanced with regularization'))
+        if no_reg:
+            raise (NotImplementedError('Not implemented unbalanced_type="{}" with no regularization'.format(unbalanced_type)))
+        raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type)))
+
+    # compute cost matrix M and use solve function
+    M = dist(X_s, X_t, metric)
+
+    # max_iter and tol are forwarded untouched (possibly None) so that each
+    # solver keeps its own default values, as documented above
+    res = solve(M, a, b, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type,
+                n_threads=n_threads, max_iter=max_iter, plan_init=plan_init, potentials_init=potentials_init,
+                tol=tol, verbose=verbose)
+    return res
diff --git a/test/test_solvers.py b/test/test_solvers.py
index f0f5b638f..18572b90a 100644
--- a/test/test_solvers.py
+++ b/test/test_solvers.py
@@ -255,3 +255,76 @@ def test_solve_gromov_not_implemented(nx):
         ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=1.5)
     with pytest.raises(NotImplementedError):
         ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=0.5, symmetric=False)
+
+
+def test_solve_sample(nx):
+    # test solve_sample when is_Lazy = False
+    n = 100
+    X_s = np.reshape(1.0 * np.arange(n), (n, 1))
+    X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))
+
+    a = ot.utils.unif(X_s.shape[0])
+    b = ot.utils.unif(X_t.shape[0])
+
+    # solve unif weights
+    sol0 = ot.solve_sample(X_s, X_t)
+
+    # solve with explicit weights
+    sol = ot.solve_sample(X_s, X_t, a, b)
+
+    # check some attributes
+    sol.potentials
+    sol.sparse_plan
+    sol.marginals
+    sol.status
+
+    assert_allclose_sol(sol0, sol)
+
+    # solve in backend
+    X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b)
+    solb = ot.solve_sample(X_sb, X_tb, ab, bb)
+
+    assert_allclose_sol(sol, solb)
+
+    # test not implemented unbalanced and check raise
+    with pytest.raises(NotImplementedError):
+        sol0 = ot.solve_sample(X_s, X_t, unbalanced=1, unbalanced_type='cryptic divergence')
+
+    # test not implemented reg_type and check raise
+    with pytest.raises(NotImplementedError):
+        sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence')
+
+
+def test_lazy_solve_sample(nx):
+    # test solve_sample when is_Lazy = True
+    n = 100
+    X_s = np.reshape(1.0 * np.arange(n), (n, 1))
+    X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))
+
+    a = ot.utils.unif(X_s.shape[0])
+    b = ot.utils.unif(X_t.shape[0])
+    X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b)
+
+    # The lazy solvers are not implemented yet: every configuration must raise.
+    # TODO: once they are, check potentials / lazy_plan and compare solutions
+    # across weights and backends as in test_solve_sample.
+
+    # test not implemented reg != 0 + balanced and check raise
+    with pytest.raises(NotImplementedError):
+        ot.solve_sample(X_s, X_t, reg=0.1, is_Lazy=True)
+    with pytest.raises(NotImplementedError):
+        ot.solve_sample(X_s, X_t, a, b, reg=0.1, is_Lazy=True)
+    with pytest.raises(NotImplementedError):
+        ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, is_Lazy=True)
+
+    # test not implemented reg == 0 (or None) + balanced and check raise
+    with pytest.raises(NotImplementedError):
+        ot.solve_sample(X_s, X_t, is_Lazy=True)
+
+    # test not implemented reg == 0 (or None) + unbalanced and check raise
+    with pytest.raises(NotImplementedError):
+        ot.solve_sample(X_s, X_t, unbalanced=1, unbalanced_type="kl", is_Lazy=True)
+
+    # test not implemented reg != 0 + unbalanced and check raise
+    with pytest.raises(NotImplementedError):
+        ot.solve_sample(X_s, X_t, reg=0.1, unbalanced=1, unbalanced_type="kl", is_Lazy=True)