diff --git a/ot/__init__.py b/ot/__init__.py index f16b6fcfc..cb00f4553 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -35,6 +35,7 @@ from . import factored from . import solvers from . import gaussian +from . import lowrank # OT functions from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, @@ -50,7 +51,8 @@ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov +from .solvers import solve, solve_gromov, solve_sample +from .lowrank import lowrank_sinkhorn # utils functions from .utils import dist, unif, tic, toc, toq diff --git a/ot/lowrank.py b/ot/lowrank.py new file mode 100644 index 000000000..22ff8b754 --- /dev/null +++ b/ot/lowrank.py @@ -0,0 +1,210 @@ +################################################################################################################# +############################################## WORK IN PROGRESS ################################################# +################################################################################################################# + +## Implementation of the LR-Dykstra algorithm and low rank sinkhorn algorithms + +import warnings +from .utils import unif, list_to_array, dist +from .backend import get_backend + + + +################################## LR-DYSKTRA ALGORITHM ########################################## + +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): + """ + Implementation of the Dykstra algorithm for low rank sinkhorn + """ + + # get dykstra parameters + q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_p + + # POT backend + eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2) + q3_1, q3_2, v1_, v2_, q1, q2 = list_to_array(q3_1, q3_2, v1_, v2_, q1, q2) + + nx = get_backend(eps1, eps2, eps3, p1, p2, q3_1, q3_2, v1_, v2_, q1, q2) + + # ------- Dykstra algorithm ------ + g_ = eps3 + + u1 = p1 / nx.dot(eps1, v1_) + u2 = p2 / nx.dot(eps2, v2_) + + g = nx.maximum(alpha, g_ * q3_1) + q3_1 = (g_ * q3_1) / g + g_ = g + + prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) + prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) + g = (g_ * q3_2 * prod1 * prod2)**(1/3) + + v1 = g / nx.dot(eps1.T,u1) + v2 = g / nx.dot(eps2.T,u2) + + q1 = (v1_ * q1) / v1 + q2 = (v2_ * q2) / v2 + q3_2 = (g_ * q3_2) / g + + v1_, v2_ = v1, v2 + g_ = g + + # Compute error + err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) + err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2)) + err = err1 + err2 + + # Compute low rank matrices Q, R + Q = u1[:,None] * eps1 * v1[None,:] + R = u2[:,None] * eps2 * v2[None,:] + + dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] + + return Q, R, g, err, dykstra_p + + + +#################################### LOW RANK SINKHORN ALGORITHM ######################################### + + +def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidean', alpha="auto", + numItermax=10000, stopThr=1e-9, warn=True, verbose=False): + r''' + Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints. + + This function returns the two low-rank matrix decomposition of the OT plan (Q,R), as well as the weight vector g. 
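+
+    The objective optimized here, following Scetbon, Cuturi and Peyré (2021) (see the
+    References below; this formulation is a sketch taken from the cited paper, not text
+    from this PR), is
+
+    .. math::
+        \mathop{\min}_{(Q, R, g) \in \mathcal{C}(\mathbf{a}, \mathbf{b}, r)} \quad
+            \langle \mathbf{M}, Q \mathrm{diag}(1/g) R^T \rangle - \mathrm{reg} \cdot H((Q, R, g))
+
+    where :math:`\mathbf{M}` is the cost matrix between the samples,
+    :math:`\mathcal{C}(\mathbf{a}, \mathbf{b}, r)` is the set of couplings with marginals
+    :math:`\mathbf{a}, \mathbf{b}` and nonnegative rank at most :math:`r`, and :math:`H`
+    is an entropy term.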
+ + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + a : array-like, shape (n_samples_a,) + samples weights in the source domain + b : array-like, shape (n_samples_b,) + samples weights in the target domain + reg : float, optional + Regularization term >0 + rank: int, optional + Nonnegative rank of the OT plan + alpha: int, optional + Lower bound for the weight vector g (>0 and <1/r) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + + + Returns + ------- + Q : array-like, shape (n_samples_a, r) + First low-rank matrix decomposition of the OT plan + R: array-like, shape (n_samples_b, r) + Second low-rank matrix decomposition of the OT plan + g : array-like, shape (r, ) + Weight vector for the low-rank decomposition of the OT plan + + + References + ---------- + + .. Scetbon, M., Cuturi, M., & Peyré, G (2021). + Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737. + + ''' + + X_s, X_t = list_to_array(X_s, X_t) + nx = get_backend(X_s, X_t) + + ns, nt = X_s.shape[0], X_t.shape[0] + if a is None: + a = nx.from_numpy(unif(ns), type_as=X_s) + if b is None: + b = nx.from_numpy(unif(nt), type_as=X_s) + + # Compute cost matrix + M = dist(X_s,X_t, metric=metric) + + # Compute rank + rank = min(ns, nt, rank) + r = rank + + if alpha == 'auto': + alpha = 1.0 / (r + 1) + + if (1/r < alpha) or (alpha < 0): + warnings.warn("The provided alpha value might lead to instabilities.") + + + # Compute gamma + L = nx.sqrt((2/(alpha**4))*(nx.norm(M)**2) + (reg + (2/(alpha**3))*(nx.norm(M))**2)) + gamma = 1/(2*L) + + # Initialisation + Q, R, g = nx.ones((ns,r)), nx.ones((nt,r)), nx.ones(r) + q3_1, q3_2 = nx.ones(r), nx.ones(r) + v1_, v2_ = nx.ones(r), nx.ones(r) + q1, q2 = nx.ones(r), nx.ones(r) + dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] + err = 1 + + for ii in range(numItermax): + CR = nx.dot(M,R) + C_t_Q = nx.dot(M.T,Q) + diag_g = (1/g)[:,None] + + eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) + eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) + omega = nx.diag(nx.dot(Q.T, CR)) + eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) + + Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) + + if err < stopThr: + break + + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + else: + if warn: + warnings.warn("Sinkhorn did not converge. 
You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + + return Q, R, g + + + + + +############################################################################ +## Test with X_s, X_t from ot.datasets +############################################################################# + +# import numpy as np +# import ot + +# Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) +# Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) + +# ns = Xs.shape[0] +# nt = Xt.shape[0] + +# a = unif(ns) +# b = unif(nt) + +# Q, R, g = lowrank_sinkhorn(Xs, Xt, reg=0.1, metric='euclidean', verbose=True, numItermax=100) +# M = ot.dist(Xs,Xt) +# P = np.dot(Q,np.dot(np.diag(1/g),R.T)) + +# print(np.sum(P)) + + + + diff --git a/ot/solvers.py b/ot/solvers.py index 0313cf588..c176969ca 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -848,3 +848,169 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx) return res + + + + + + +################################## WORK IN PROGRESS ##################################### + +## Implementation of the ot.solve_sample function +## Function isn't complete, still work in progress for reg == 0 / reg is None case (and unbalanced cases) + + +from .utils import unif, list_to_array, dist, OTResultLazy +from .bregman import empirical_sinkhorn + + +def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, + unbalanced_type='KL', is_Lazy=False, batch_size=None, n_threads=1, max_iter=None, plan_init=None, + potentials_init=None, tol=None, verbose=False): + + r"""Solve the discrete optimal transport problem using the samples in the source and target domains. + It returns either a :any:`OTResult` or :any:`OTResultLazy` object. + + The function solves the following general optimal transport problem + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By + default ``reg=None`` and there is no regularization. The unbalanced marginal + penalization can be selected with `unbalanced` (:math:`\lambda_u`) and + `unbalanced_type`. By default ``unbalanced=None`` and the function + solves the exact optimal transport problem (respecting the marginals). 
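+
+    Here :math:`\mathbf{M}` denotes the pairwise cost matrix built internally from the
+    input samples, :math:`M_{i,j} = d(\mathbf{x}^s_i, \mathbf{x}^t_j)` with :math:`d` the
+    distance selected by `metric` (this note and the symbol :math:`d` are added for
+    clarity only).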
+ + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + a : array-like, shape (dim_a,), optional + Samples weights in the source domain (default is uniform) + b : array-like, shape (dim_b,), optional + Samples weights in the source domain (default is uniform) + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + reg_type : str, optional + Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" + unbalanced : float, optional + Unbalanced penalization weight :math:`\lambda_u`, by default None + (balanced OT) + unbalanced_type : str, optional + Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" + is_Lazy : bool, optional + Return :any:`OTResultlazy` object to reduce memory cost when True, by default False + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + max_iter : int, optional + Maximum number of iteration, by default None (default values in each solvers) + plan_init : array_like, shape (dim_a, dim_b), optional + Initialization of the OT plan for iterative methods, by default None + potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional + Initialization of the OT dual potentials for iterative methods, by default None + tol : _type_, optional + Tolerance for solution precision, by default None (default values in each solvers) + verbose : bool, optional + Print information in the solver, by default False + + Returns + ------- + + res_lazy : OTResultLazy() + Result of the optimization problem. This class only returns a partial OT plan and the OT dual potentials to reduce memory costs. + The information can be obtained as follows: + + - res.lazy_plan : OT plan computed on a subsample of X_s and X_t + - res.potentials : OT dual potentials + + See :any:`OTResultLazy` for more information. + + res : OTResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + + See :any:`OTResult` for more information. + + + """ + + X_s, X_t = list_to_array(X_s,X_t) + + # detect backend + arr = [X_s,X_t] + if a is not None: + arr.append(a) + if b is not None: + arr.append(b) + nx = get_backend(*arr) + + # create uniform weights if not given + ns, nt = X_s.shape[0], X_t.shape[0] + if a is None: + a = nx.from_numpy(unif(ns), type_as=X_s) + if b is None: + b = nx.from_numpy(unif(nt), type_as=X_s) + + # default values for solutions + potentials = None + lazy_plan = None + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if batch_size is None: + batch_size = 100 + + if is_Lazy: + ################# WIP #################### + if reg is None or reg == 0: # EMD solver for isLazy ? + + if unbalanced is None: # balanced EMD solver for isLazy ? 
+ raise (NotImplementedError('Not implemented balanced with no regularization')) + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with no regularization'.format(unbalanced_type))) + + + ############################################# + + else: + if unbalanced is None: + u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + # compute potentials + potentials = (log["u"], log["v"]) + + # compute lazy_plan + ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) + M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) + K = nx.exp(M / (-reg)) + lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) + + res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) + return res_lazy + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) + + else: + # compute cost matrix M and use solve function + M = dist(X_s, X_t, metric) + + res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose) + return res + + + + diff --git a/ot/utils.py b/ot/utils.py index 4efcb225e..e03884268 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -968,3 +968,93 @@ def citation(self): url = {http://jmlr.org/papers/v22/20-451.html} } """ + + + +############################## WORK IN PROGRESS #################################### + +## Implementation of the OTResultLazy class for ot.solve_sample() with potentials and lazy_plan as attributes + +class OTResultLazy: + def __init__(self, potentials=None, lazy_plan=None, backend=None): + + self._potentials = potentials + self._lazy_plan = lazy_plan + self._backend = backend if backend is not None else NumpyBackend() + + + # Dual potentials -------------------------------------------- + + def __repr__(self): + s = 'OTResultLazy(' + if self._lazy_plan is not None: + s += 'lazy_plan={}(shape={}),'.format(self._lazy_plan.__class__.__name__, self._lazy_plan.shape) + + if s[-1] != '(': + s = s[:-1] + ')' + else: + s = s + ')' + return s + + @property + def potentials(self): + """Dual potentials, i.e. Lagrange multipliers for the marginal constraints. + + This pair of arrays has the same shape, numerical type + and properties as the input weights "a" and "b". + """ + if self._potentials is not None: + return self._potentials + else: + raise NotImplementedError() + + @property + def potential_a(self): + """First dual potential, associated to the "source" measure "a".""" + if self._potentials is not None: + return self._potentials[0] + else: + raise NotImplementedError() + + @property + def potential_b(self): + """Second dual potential, associated to the "target" measure "b".""" + if self._potentials is not None: + return self._potentials[1] + else: + raise NotImplementedError() + + # Transport plan ------------------------------------------- + @property + def lazy_plan(self): + """A subset of the Transport plan, encoded as a dense array.""" + + if self._lazy_plan is not None: + return self._lazy_plan + else: + raise NotImplementedError() + + @property + def citation(self): + """Appropriate citation(s) for this result, in plain text and BibTex formats.""" + + # The string below refers to the POT library: + # successor methods may concatenate the relevant references + # to the original definitions, solvers and underlying numerical backends. 
+ return """POT library: + + POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. + Website: https://pythonot.github.io/ + Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer; + + @article{flamary2021pot, + author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer}, + title = {{POT}: {Python} {Optimal} {Transport}}, + journal = {Journal of Machine Learning Research}, + year = {2021}, + volume = {22}, + number = {78}, + pages = {1-8}, + url = {http://jmlr.org/papers/v22/20-451.html} + } + """ \ No newline at end of file diff --git a/test/test_lowrank.py b/test/test_lowrank.py new file mode 100644 index 000000000..6e1f24067 --- /dev/null +++ b/test/test_lowrank.py @@ -0,0 +1,84 @@ +##################################################################################################### +####################################### WORK IN PROGRESS ############################################ +##################################################################################################### + + +""" Test for low rank sinkhorn solvers """ + +import ot +import numpy as np +import pytest +from itertools import product + + +def test_LR_Dykstra(): + # test for LR_Dykstra algorithm ? catch nan values ? 
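+    # One possible sketch for later (kept as comments, the test body is still TBD):
+    # run LR_Dysktra for a handful of iterations on small uniform p1/p2 with random
+    # positive eps1/eps2/eps3, then assert that the returned Q, R, g and err contain
+    # no NaN values, e.g. with np.isfinite(...).all().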
+ pass + + +@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +def test_lowrank_sinkhorn(verbose, warn): + # test low rank sinkhorn + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + Q_sqe, R_sqe, g_sqe = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) + P_sqe = np.dot(Q_sqe,np.dot(np.diag(1/g_sqe),R_sqe.T)) + + Q_m, R_m, g_m = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1, metric='euclidean') + P_m = np.dot(Q_m,np.dot(np.diag(1/g_m),R_m.T)) + + # check constraints + np.testing.assert_allclose( + a, P_sqe.sum(1), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + b, P_sqe.sum(0), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + a, P_m.sum(1), atol=1e-05) # metric euclidian + np.testing.assert_allclose( + b, P_m.sum(0), atol=1e-05) # metric euclidian + + with pytest.warns(UserWarning): + ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, stopThr=0, numItermax=1) + + + +@pytest.mark.parametrize(("alpha, rank"),((0.8,2),(0.5,3),(0.2,4))) +def test_lowrank_sinkhorn_alpha_warning(alpha,rank): + # test warning for value of alpha + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + with pytest.warns(UserWarning): + ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, r=rank, alpha=alpha, warn=False) + + + +def test_lowrank_sinkhorn_backends(nx): + # test low rank sinkhorn for different backends + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t) + + Q, R, g = nx.to_numpy(ot.lowrank.lowrank_sinkhorn(X_sb, X_tb, ab, bb, 0.1)) + P = np.dot(Q,np.dot(np.diag(1/g),R.T)) + + np.testing.assert_allclose(a, P.sum(1), atol=1e-05) + np.testing.assert_allclose(b, P.sum(0), atol=1e-05) + + + + diff --git a/test/test_solvers.py b/test/test_solvers.py index f0f5b638f..5a05d54cf 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -255,3 +255,103 @@ def test_solve_gromov_not_implemented(nx): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=1.5) with pytest.raises(NotImplementedError): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=0.5, symmetric=False) + + + + +########################################################################################################### +############################################ WORK IN PROGRESS ############################################# +########################################################################################################### + +def assert_allclose_sol_sample(sol1, sol2): + # test attributes of OTResultLazy class + lst_attr = ['potentials','potential_a', 'potential_b', 'lazy_plan'] + + nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() + nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() + + for attr in lst_attr: + try: + np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr))) + except NotImplementedError: + pass + + +@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type)) +def test_solve_sample(nx): + # test solve_sample when is_Lazy = False + n = 100 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) 
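+    # note: X_s and X_t are the same 1-D grid with uniform weights, so the exact OT
+    # plan should be (numerically close to) the zero-cost diagonal coupling np.eye(n) / n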
+ + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + # solve unif weights + sol0 = ot.solve_sample(X_s, X_t) + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b) + + # check some attributes + sol.potentials + sol.sparse_plan + sol.marginals + sol.status + + assert_allclose_sol(sol0, sol) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb) + + assert_allclose_sol(sol, solb) + + # test not implemented unbalanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, unbalanced=1, unbalanced_type='cryptic divergence') + + # test not implemented reg_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence') + + + +def test_lazy_solve_sample(nx): + # test solve_sample when is_Lazy = True + n = 100 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + # solve unif weights + sol0 = ot.solve_sample(X_s, X_t, reg=0.1, is_Lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, is_Lazy=True) + + # check some attributes + sol.potentials + sol.lazy_plan + + assert_allclose_sol_sample(sol0, sol) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, is_Lazy=True) + + assert_allclose_sol_sample(sol, solb) + + # test not implemented reg==0 (or None) + balanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, is_Lazy=True) # reg == 0 (or None) + unbalanced= None are default + + # test not implemented reg==0 (or None) + unbalanced_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", is_Lazy=True) # reg == 0 (or None) is default + + # test not implemented reg != 0 + unbalanced_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", is_Lazy=True) \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py index 40324518e..a14be460e 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -401,3 +401,30 @@ def test_get_coordinate_circle(): x_p = ot.utils.get_coordinate_circle(x) np.testing.assert_allclose(u[0], x_p) + + + +############################################################################################## +##################################### WORK IN PROGRESS ####################################### +############################################################################################## + +# test function for OTResultLazy + +def test_OTResultLazy(): + + res_lazy = ot.utils.OTResultLazy() + + # test print + print(res_lazy) + + # tets get citation + print(res_lazy.citation) + + lst_attributes = ['lazy_plan', + 'potential_a', + 'potential_b', + 'potentials'] + + for at in lst_attributes: + with pytest.raises(NotImplementedError): + getattr(res_lazy, at) \ No newline at end of file
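
For reference, a minimal end-to-end sketch of how the two entry points added in this PR are
meant to be used. This is illustrative only: it mirrors the WIP signatures above and the
commented-out check at the bottom of ot/lowrank.py, and may need to change as the
reg == 0 / unbalanced branches of solve_sample get implemented.

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    X_s = rng.randn(100, 2)          # source samples
    X_t = rng.randn(120, 2) + 1.0    # target samples, shifted

    # dense path: the cost matrix is built internally and ot.solve is called
    res = ot.solve_sample(X_s, X_t, reg=0.1)
    print(res.value)

    # lazy path: only a sub-block of the plan and the dual potentials are stored
    res_lazy = ot.solve_sample(X_s, X_t, reg=0.1, is_Lazy=True)
    print(res_lazy.lazy_plan.shape)

    # low-rank Sinkhorn: rebuild a dense plan from the (Q, R, g) factorization
    Q, R, g = ot.lowrank.lowrank_sinkhorn(X_s, X_t, reg=0.1, rank=3)
    P = np.dot(Q, np.dot(np.diag(1 / g), R.T))
    print(P.sum())   # should be close to 1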