diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 89c5be433..c185e18a7 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -50,7 +50,7 @@ The contributors to this library are:
 * [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
 * [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
 * [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers)
-* [Laurène David](https://github.com/laudavid) (Low rank sinkhorn)
+* [Laurène David](https://github.com/laudavid) (Low rank sinkhorn, Low rank Gromov-Wasserstein samples)

 ## Acknowledgments
diff --git a/README.md b/README.md
index 88dce689a..f1149a008 100644
--- a/README.md
+++ b/README.md
@@ -357,3 +357,5 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
 [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf).

 [66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. [Entropic estimation of optimal transport maps](https://arxiv.org/pdf/2109.12004.pdf). arXiv preprint arXiv:2109.12004 (2021).
+
+[67] Scetbon, M., Peyré, G. & Cuturi, M. (2022). [Linear-Time Gromov Wasserstein Distances using Low Rank Couplings and Costs](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf). In International Conference on Machine Learning (ICML), 2022.
\ No newline at end of file
diff --git a/RELEASES.md b/RELEASES.md
index 106042af2..c31081451 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -7,6 +7,7 @@
 + Continuous entropic mapping (PR #613)
 + New general unbalanced solvers for `ot.solve` and BFGS solver and illustrative example (PR #620)
 + Add gradient computation with envelope theorem to sinkhorn solver of `ot.solve` with `grad='envelope'` (PR #605).
++ Added support for [Low rank Gromov-Wasserstein](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf) with `ot.gromov.lowrank_gromov_wasserstein_samples` (PR #614)

 #### Closed issues
 - Fix gpu compatibility of sr(F)GW solvers when `G0 is not None` (PR #596)
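For quick orientation, a minimal call sketch of the new solver introduced by this PR (the data here is purely illustrative; the function and its `rank`/`log` parameters are defined in `ot/gromov/_lowrank.py` below):

```python
# Minimal sketch of the new API; see the full gallery example below.
import numpy as np
import ot

X = np.random.RandomState(0).randn(100, 2)   # source samples
Y = np.random.RandomState(1).randn(120, 3)   # target samples

# Returns the low rank factors (Q, R, g) of the coupling;
# the full plan is Q diag(1/g) R^T.
Q, R, g, log = ot.gromov.lowrank_gromov_wasserstein_samples(X, Y, rank=10, log=True)
P = log["lazy_plan"][:]  # materialize the lazy plan only if needed
```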
+""" + +# Author: Laurène David +# +# License: MIT License +# +# sphinx_gallery_thumbnail_number = 3 + +#%% +import numpy as np +import matplotlib.pylab as pl +import ot.plot +import time + +############################################################################## +# Generate data +# ------------- + +#%% parameters +n_samples = 200 + +# Generate 2D and 3D curves +theta = np.linspace(-4 * np.pi, 4 * np.pi, n_samples) +z = np.linspace(1, 2, n_samples) +r = z**2 + 1 +x = r * np.sin(theta) +y = r * np.cos(theta) + +# Source and target distribution +X = np.concatenate([x.reshape(-1, 1), z.reshape(-1, 1)], axis=1) +Y = np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], axis=1) + + +############################################################################## +# Plot data +# ------------ + +#%% +# Plot the source and target samples +fig = pl.figure(1, figsize=(10, 4)) + +ax = fig.add_subplot(121) +ax.plot(X[:, 0], X[:, 1], color="blue", linewidth=6) +ax.tick_params(left=False, right=False, labelleft=False, + labelbottom=False, bottom=False) +ax.set_title("2D curve (source)") + +ax2 = fig.add_subplot(122, projection="3d") +ax2.plot(Y[:, 0], Y[:, 1], Y[:, 2], c='red', linewidth=6) +ax2.tick_params(left=False, right=False, labelleft=False, + labelbottom=False, bottom=False) +ax2.view_init(15, -50) +ax2.set_title("3D curve (target)") + +pl.tight_layout() +pl.show() + + +############################################################################## +# Entropic Gromov-Wasserstein +# ------------ + +#%% + +# Compute cost matrices +C1 = ot.dist(X, X, metric="sqeuclidean") +C2 = ot.dist(Y, Y, metric="sqeuclidean") + +# Scale cost matrices +r1 = C1.max() +r2 = C2.max() + +C1 = C1 / r1 +C2 = C2 / r2 + + +# Solve entropic gw +reg = 5 * 1e-3 + +start = time.time() +gw, log = ot.gromov.entropic_gromov_wasserstein( + C1, C2, tol=1e-3, epsilon=reg, + log=True, verbose=False) + +end = time.time() +time_entropic = end - start + +entropic_gw_loss = np.round(log['gw_dist'], 3) + +# Plot entropic gw +pl.figure(2) +pl.imshow(gw, interpolation="nearest", aspect="auto") +pl.title("Entropic Gromov-Wasserstein (loss={})".format(entropic_gw_loss)) +pl.show() + + +############################################################################## +# Low rank squared euclidean cost matrices +# ------------ +# %% + +# Compute the low rank sqeuclidean cost decompositions +A1, A2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X, X, rescale_cost=False) +B1, B2 = ot.lowrank.compute_lr_sqeuclidean_matrix(Y, Y, rescale_cost=False) + +# Scale the low rank cost matrices +A1, A2 = A1 / np.sqrt(r1), A2 / np.sqrt(r1) +B1, B2 = B1 / np.sqrt(r2), B2 / np.sqrt(r2) + + +############################################################################## +# Low rank Gromov-Wasserstein +# ------------ +# %% + +# Solve low rank gromov-wasserstein with different ranks +list_rank = [10, 50] +list_P_GW = [] +list_loss_GW = [] +list_time_GW = [] + +for rank in list_rank: + start = time.time() + + Q, R, g, log = ot.lowrank_gromov_wasserstein_samples( + X, Y, reg=0, rank=rank, rescale_cost=False, cost_factorized_Xs=(A1, A2), + cost_factorized_Xt=(B1, B2), seed_init=49, numItermax=1000, log=True, stopThr=1e-6, + ) + end = time.time() + + P = log["lazy_plan"][:] + loss = log["value"] + + list_P_GW.append(P) + list_loss_GW.append(np.round(loss, 3)) + list_time_GW.append(end - start) + + +# %% +# Plot low rank GW with different ranks +pl.figure(3, figsize=(10, 4)) + +pl.subplot(1, 2, 1) +pl.imshow(list_P_GW[0], interpolation="nearest", 
aspect="auto") +pl.title('Low rank GW (rank=10, loss={})'.format(list_loss_GW[0])) + +pl.subplot(1, 2, 2) +pl.imshow(list_P_GW[1], interpolation="nearest", aspect="auto") +pl.title('Low rank GW (rank=50, loss={})'.format(list_loss_GW[1])) + +pl.tight_layout() +pl.show() + + +# %% +# Compare computation time between entropic GW and low rank GW +print("Entropic GW: {:.2f}s".format(time_entropic)) +print("Low rank GW (rank=10): {:.2f}s".format(list_time_GW[0])) +print("Low rank GW (rank=50): {:.2f}s".format(list_time_GW[1])) diff --git a/ot/__init__.py b/ot/__init__.py index 609f9ff37..d8ac5ac28 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -49,7 +49,8 @@ from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance, sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif) from .gromov import (gromov_wasserstein, gromov_wasserstein2, - gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) + gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2, + lowrank_gromov_wasserstein_samples) from .weak import weak_optimal_transport from .factored import factored_optimal_transport from .solvers import solve, solve_gromov, solve_sample @@ -71,5 +72,5 @@ 'factored_optimal_transport', 'solve', 'solve_gromov', 'solve_sample', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', - 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', - 'lowrank_sinkhorn'] + 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'lowrank_sinkhorn', + 'lowrank_gromov_wasserstein_samples'] diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py index 4d77fc57a..b33dafd32 100644 --- a/ot/gromov/__init__.py +++ b/ot/gromov/__init__.py @@ -47,6 +47,8 @@ fused_gromov_wasserstein_dictionary_learning, fused_gromov_wasserstein_linear_unmixing) +from ._lowrank import (_flat_product_operator, lowrank_gromov_wasserstein_samples) + __all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', 'update_square_loss', 'update_kl_loss', 'update_feature_matrix', 'init_matrix_semirelaxed', @@ -64,4 +66,4 @@ 'entropic_semirelaxed_gromov_wasserstein2', 'entropic_semirelaxed_fused_gromov_wasserstein', 'entropic_semirelaxed_fused_gromov_wasserstein2', 'gromov_wasserstein_dictionary_learning', 'gromov_wasserstein_linear_unmixing', 'fused_gromov_wasserstein_dictionary_learning', - 'fused_gromov_wasserstein_linear_unmixing'] + 'fused_gromov_wasserstein_linear_unmixing', 'lowrank_gromov_wasserstein_samples'] diff --git a/ot/gromov/_lowrank.py b/ot/gromov/_lowrank.py new file mode 100644 index 000000000..5bab15edc --- /dev/null +++ b/ot/gromov/_lowrank.py @@ -0,0 +1,313 @@ +""" +Low rank Gromov-Wasserstein solver +""" + +# Author: Laurène David +# +# License: MIT License + + +import warnings +from ..utils import unif, get_lowrank_lazytensor +from ..backend import get_backend +from ..lowrank import compute_lr_sqeuclidean_matrix, _init_lr_sinkhorn, _LR_Dysktra + + +def _flat_product_operator(X, nx=None): + r""" + Implementation of the flattened out-product operator. + + This function is used in low rank gromov wasserstein to compute the low rank decomposition of + a cost matrix's squared hadamard product (page 6 in paper). 

def lowrank_gromov_wasserstein_samples(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=1e-10, gamma_init="rescale",
                                       rescale_cost=True, cost_factorized_Xs=None, cost_factorized_Xt=None, stopThr=1e-4, numItermax=1000,
                                       stopThr_dykstra=1e-3, numItermax_dykstra=10000, seed_init=49, warn=True, warn_dykstra=False, log=False):

    r"""
    Solve the entropic regularization Gromov-Wasserstein transport problem under low-nonnegative rank constraints
    on the couplings and cost matrices.

    Squared Euclidean distance matrices are considered for the source and target distributions.

    The function solves the following optimization problem:

    .. math::
        \mathop{\min_{(Q,R,g) \in \mathcal{C(a,b,r)}}} \mathcal{Q}_{A,B}(Q\mathrm{diag}(1/g)R^T) -
        \epsilon \cdot H((Q,R,g))

    where :

    - :math:`A` is the (`dim_a`, `dim_a`) square pairwise cost matrix of the source domain.
    - :math:`B` is the (`dim_b`, `dim_b`) square pairwise cost matrix of the target domain.
    - :math:`\mathcal{Q}_{A,B}` is the quadratic objective function of the Gromov-Wasserstein plan.
    - :math:`Q` and :math:`R` are the low-rank matrix decomposition of the Gromov-Wasserstein plan.
    - :math:`g` is the weight vector for the low-rank decomposition of the Gromov-Wasserstein plan.
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1).
    - :math:`r` is the rank of the Gromov-Wasserstein plan.
    - :math:`\mathcal{C(a,b,r)}` are the low-rank couplings of the OT problem.
    - :math:`H((Q,R,g))` is the sum of the three respective entropies, evaluated term by term.


    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim_Xs)
        Samples in the source domain
    X_t : array-like, shape (n_samples_b, dim_Xt)
        Samples in the target domain
    a : array-like, shape (n_samples_a,), optional
        Sample weights in the source domain.
        If left to its default value None, the uniform distribution is taken.
    b : array-like, shape (n_samples_b,), optional
        Sample weights in the target domain.
        If left to its default value None, the uniform distribution is taken.
    reg : float, optional
        Regularization term >=0
    rank : int, optional. Default is None. (>0)
        Nonnegative rank of the OT plan. If None, min(ns, nt) is considered.
    alpha : float, optional. Default is 1e-10. (>0 and <1/r)
        Lower bound for the weight vector g.
    rescale_cost : bool, optional. Default is True
        Rescale the low rank factorization of the sqeuclidean cost matrix
    seed_init : int, optional. Default is 49. (>0)
        Random state for the 'random' initialization of the low rank couplings
    gamma_init : str, optional. Default is "rescale".
        Initialization strategy for gamma, either 'rescale' or 'theory'.
        Gamma is a constant that scales the convergence criterion of the Mirror Descent
        optimization scheme used to compute the low-rank couplings (Q, R and g).
    numItermax : int, optional. Default is 1000.
        Max number of iterations for Low Rank GW
    stopThr : float, optional. Default is 1e-4.
        Stop threshold on error (>0) for Low Rank GW.
        The error is the sum of the Kullback-Leibler divergences computed for each low rank
        coupling (Q, R and g), scaled using gamma.
    numItermax_dykstra : int, optional. Default is 10000.
        Max number of iterations for the Dykstra algorithm
    stopThr_dykstra : float, optional. Default is 1e-3.
        Stop threshold on error (>0) in Dykstra
    cost_factorized_Xs : tuple, optional. Default is None
        Tuple with two pre-computed low rank decompositions (A1, A2) of the source cost
        matrix. Both matrices should have a shape of (n_samples_a, dim_Xs + 2).
        If None, the low rank cost matrices will be computed as sqeuclidean cost matrices.
    cost_factorized_Xt : tuple, optional. Default is None
        Tuple with two pre-computed low rank decompositions (B1, B2) of the target cost
        matrix. Both matrices should have a shape of (n_samples_b, dim_Xt + 2).
        If None, the low rank cost matrices will be computed as sqeuclidean cost matrices.
    warn : bool, optional
        if True, raises a warning if the low rank GW algorithm does not converge.
    warn_dykstra : bool, optional
        if True, raises a warning if the Dykstra algorithm does not converge.
    log : bool, optional
        record log if True


    Returns
    -------
    Q : array-like, shape (n_samples_a, r)
        First low-rank matrix decomposition of the OT plan
    R : array-like, shape (n_samples_b, r)
        Second low-rank matrix decomposition of the OT plan
    g : array-like, shape (r, )
        Weight vector for the low-rank decomposition of the OT plan
    log : dict (lazy_plan, value and value_quad)
        log dictionary returned only if log==True in parameters


    References
    ----------
    .. [67] Scetbon, M., Peyré, G. & Cuturi, M. (2022).
        "Linear-Time Gromov Wasserstein Distances using Low Rank Couplings and Costs".
        In International Conference on Machine Learning (ICML), 2022.
    """

    # POT backend
    nx = get_backend(X_s, X_t)
    ns, nt = X_s.shape[0], X_t.shape[0]

    # Initialize weights a, b
    if a is None:
        a = unif(ns, type_as=X_s)
    if b is None:
        b = unif(nt, type_as=X_t)

    # Compute rank (see Section 3.1, def 1)
    r = rank
    if rank is None:
        r = min(ns, nt)
    else:
        r = min(ns, nt, rank)

    if r <= 0:
        raise ValueError("The rank parameter cannot have a negative or null value")

    # Dykstra won't converge if 1/rank < alpha (see Section 3.2)
    if 1 / r < alpha:
        raise ValueError("alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format(
            a=alpha, r=1 / r))

    if cost_factorized_Xs is not None:
        A1, A2 = cost_factorized_Xs
    else:
        A1, A2 = compute_lr_sqeuclidean_matrix(X_s, X_s, rescale_cost, nx=nx)

    if cost_factorized_Xt is not None:
        B1, B2 = cost_factorized_Xt
    else:
        B1, B2 = compute_lr_sqeuclidean_matrix(X_t, X_t, rescale_cost, nx=nx)

    # Initial values for LR couplings (Q, R, g) with LOT
    Q, R, g = _init_lr_sinkhorn(
        X_s, X_t, a, b, r, init="random", random_state=seed_init, reg_init=None, nx=nx
    )

    # Gamma initialization
    if gamma_init == "theory":
        L = (27 * nx.norm(A1) * nx.norm(A2)) / alpha**4
        gamma = 1 / (2 * L)

    if gamma_init not in ["rescale", "theory"]:
        raise NotImplementedError('Not implemented gamma_init="{}"'.format(gamma_init))

    # initial value of error
    err = 1

    for ii in range(numItermax):
        Q_prev = Q
        R_prev = R
        g_prev = g

        if err > stopThr:
            # Compute cost matrices
            C1 = nx.dot(A2.T, Q * (1 / g)[None, :])
            C1 = - 4 * nx.dot(A1, C1)
            C2 = nx.dot(R.T, B1)
            C2 = nx.dot(C2, B2.T)
            diag_g = (1 / g)[None, :]

            # Compute C*R dot using the lr decomposition of C
            CR = nx.dot(C2, R)
            CR = nx.dot(C1, CR)
            CR_g = CR * diag_g

            # Compute C.T * Q using the lr decomposition of C
            CQ = nx.dot(C1.T, Q)
            CQ = nx.dot(C2.T, CQ)
            CQ_g = CQ * diag_g

            # Compute omega
            omega = nx.diag(nx.dot(Q.T, CR))

            # Rescale gamma at each iteration
            if gamma_init == "rescale":
                norm_1 = nx.max(nx.abs(CR_g + reg * nx.log(Q))) ** 2
                norm_2 = nx.max(nx.abs(CQ_g + reg * nx.log(R))) ** 2
                norm_3 = nx.max(nx.abs(-omega * (diag_g**2))) ** 2
                gamma = 10 / max(norm_1, norm_2, norm_3)

            # Mirror descent kernels for the three factors
            K1 = nx.exp(-gamma * CR_g - ((gamma * reg) - 1) * nx.log(Q))
            K2 = nx.exp(-gamma * CQ_g - ((gamma * reg) - 1) * nx.log(R))
            K3 = nx.exp((gamma * omega / (g**2)) - (gamma * reg - 1) * nx.log(g))

            # Update couplings with the LR Dykstra algorithm
            Q, R, g = _LR_Dysktra(
                K1, K2, K3, a, b, alpha, stopThr_dykstra, numItermax_dykstra, warn_dykstra, nx
            )

            # Update error with the symmetrized Kullback-Leibler divergence
            err_1 = ((1 / gamma) ** 2) * (nx.kl_div(Q, Q_prev) + nx.kl_div(Q_prev, Q))
            err_2 = ((1 / gamma) ** 2) * (nx.kl_div(R, R_prev) + nx.kl_div(R_prev, R))
            err_3 = ((1 / gamma) ** 2) * (nx.kl_div(g, g_prev) + nx.kl_div(g_prev, g))
            err = err_1 + err_2 + err_3

            # avoid exact zeros (division by zero and log(0) at the next iteration)
            Q = Q + 1e-16
            R = R + 1e-16
            g = g + 1e-16

        else:
            break

    else:
        if warn:
            warnings.warn(
                "Low Rank GW did not converge. You might want to "
                "increase the number of iterations `numItermax`."
            )

    # Update low rank costs with the final couplings
    C1 = nx.dot(A2.T, Q * (1 / g)[None, :])
    C1 = - 4 * nx.dot(A1, C1)
    C2 = nx.dot(R.T, B1)
    C2 = nx.dot(C2, B2.T)
    diag_g = (1 / g)[None, :]  # recompute from the final g (it changed in the last Dykstra update)

    # Compute lazy plan (using LazyTensor class)
    lazy_plan = get_lowrank_lazytensor(Q, R, 1 / g)

    # Compute value_quad (the quadratic GW objective)
    A1_, A2_ = _flat_product_operator(A1, nx), _flat_product_operator(A2, nx)
    B1_, B2_ = _flat_product_operator(B1, nx), _flat_product_operator(B2, nx)

    x_ = nx.dot(A1_, nx.dot(A2_.T, a))
    y_ = nx.dot(B1_, nx.dot(B2_.T, b))
    c1 = nx.dot(x_, a) + nx.dot(y_, b)

    G = nx.dot(C1, nx.dot(C2, R))
    G = nx.dot(Q.T, G * diag_g)
    value_quad = c1 + nx.trace(G) / 2

    if reg != 0:
        reg_Q = nx.sum(Q * nx.log(Q + 1e-16))  # entropy for Q
        reg_g = nx.sum(g * nx.log(g + 1e-16))  # entropy for g
        reg_R = nx.sum(R * nx.log(R + 1e-16))  # entropy for R
        value = value_quad + reg * (reg_Q + reg_g + reg_R)
    else:
        value = value_quad

    if log:
        dict_log = dict()
        dict_log["value"] = value
        dict_log["value_quad"] = value_quad
        dict_log["lazy_plan"] = lazy_plan

        return Q, R, g, dict_log

    return Q, R, g
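Before the tests, a minimal usage sketch (not part of the patch; the data is illustrative, and parameter names follow the function above) showing how precomputed cost factorizations can be reused across several ranks:

```python
import numpy as np
import ot
from ot.lowrank import compute_lr_sqeuclidean_matrix

rng = np.random.RandomState(0)
X_s, X_t = rng.randn(150, 2), rng.randn(180, 2)

# Factorize the squared Euclidean costs once (shape (n, d + 2) each)...
A1, A2 = compute_lr_sqeuclidean_matrix(X_s, X_s, rescale_cost=True)
B1, B2 = compute_lr_sqeuclidean_matrix(X_t, X_t, rescale_cost=True)

# ...then reuse the factors for several ranks instead of rebuilding them.
for rank in (5, 10):
    Q, R, g, log = ot.gromov.lowrank_gromov_wasserstein_samples(
        X_s, X_t, rank=rank, cost_factorized_Xs=(A1, A2),
        cost_factorized_Xt=(B1, B2), log=True,
    )
    print(rank, float(log["value"]))
```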
diff --git a/test/gromov/test_lowrank.py b/test/gromov/test_lowrank.py
new file mode 100644
index 000000000..befc5c835
--- /dev/null
+++ b/test/gromov/test_lowrank.py
@@ -0,0 +1,125 @@
""" Tests for gromov._lowrank.py """

# Author: Laurène DAVID
#
# License: MIT License

import ot
import numpy as np
import pytest


def test__flat_product_operator():
    # test flat product operator
    n, d = 100, 2
    X = np.reshape(1.0 * np.arange(2 * n), (n, d))
    A1, A2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X, X, rescale_cost=False)

    A1_ = ot.gromov._flat_product_operator(A1)
    A2_ = ot.gromov._flat_product_operator(A2)
    cost = ot.dist(X, X)

    # check that the flattened factors reconstruct the squared Hadamard product
    np.testing.assert_allclose(cost**2, np.dot(A1_, A2_.T), atol=1e-05)


def test_lowrank_gromov_wasserstein_samples():
    # test low rank Gromov-Wasserstein
    n_samples = 20  # nb samples
    mu_s = np.array([0, 0])
    cov_s = np.array([[1, 0], [0, 1]])

    X_s = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1)
    X_t = X_s[::-1].copy()

    a = ot.unif(n_samples)
    b = ot.unif(n_samples)

    Q, R, g, log = ot.gromov.lowrank_gromov_wasserstein_samples(X_s, X_t, a, b, reg=0.1, log=True, rescale_cost=False)
    P = log["lazy_plan"][:]

    # check constraints for P
    np.testing.assert_allclose(a, P.sum(1), atol=1e-04)
    np.testing.assert_allclose(b, P.sum(0), atol=1e-04)

    # check if lazy_plan is equal to the fully computed plan
    P_true = np.dot(Q, np.dot(np.diag(1 / g), R.T))
    np.testing.assert_allclose(P, P_true, atol=1e-05)

    # check warn parameter when the low rank GW algorithm doesn't converge
    with pytest.warns(UserWarning):
        ot.gromov.lowrank_gromov_wasserstein_samples(
            X_s, X_t, a, b, reg=0.1, stopThr=0, numItermax=1, warn=True, warn_dykstra=False
        )

    # check warn parameter when the Dykstra algorithm doesn't converge
    with pytest.warns(UserWarning):
        ot.gromov.lowrank_gromov_wasserstein_samples(
            X_s, X_t, a, b, reg=0.1, stopThr_dykstra=0, numItermax_dykstra=1, warn=False, warn_dykstra=True
        )


@pytest.mark.parametrize(("alpha, rank"), ((0.8, 2), (0.5, 3), (0.2, 6), (0.1, -1)))
def test_lowrank_gromov_wasserstein_samples_alpha_error(alpha, rank):
    # test that a ValueError is raised for invalid values of alpha and rank
    n_samples = 20  # nb samples
    mu_s = np.array([0, 0])
    cov_s = np.array([[1, 0], [0, 1]])

    X_s = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1)
    X_t = X_s[::-1].copy()

    a = ot.unif(n_samples)
    b = ot.unif(n_samples)

    with pytest.raises(ValueError):
        ot.gromov.lowrank_gromov_wasserstein_samples(X_s, X_t, a, b, reg=0.1, rank=rank, alpha=alpha, warn=False)


@pytest.mark.parametrize(("gamma_init"), ("rescale", "theory", "other"))
def test_lowrank_wasserstein_samples_gamma_init(gamma_init):
    # test low rank GW with different gamma init strategies
    n_samples = 20  # nb samples
    mu_s = np.array([0, 0])
    cov_s = np.array([[1, 0], [0, 1]])

    X_s = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1)
    X_t = X_s[::-1].copy()

    a = ot.unif(n_samples)
    b = ot.unif(n_samples)

    if gamma_init not in ["rescale", "theory"]:
        with pytest.raises(NotImplementedError):
            ot.gromov.lowrank_gromov_wasserstein_samples(X_s, X_t, a, b, reg=0.1, gamma_init=gamma_init, log=True)

    else:
        Q, R, g, log = ot.gromov.lowrank_gromov_wasserstein_samples(X_s, X_t, a, b, reg=0.1, gamma_init=gamma_init, log=True)
        P = log["lazy_plan"][:]

        # check constraints for P
        np.testing.assert_allclose(a, P.sum(1), atol=1e-04)
        np.testing.assert_allclose(b, P.sum(0), atol=1e-04)


@pytest.skip_backend('tf')
def test_lowrank_gromov_wasserstein_samples_backends(nx):
    # test low rank GW for different backends
    n_samples = 20  # nb samples
    mu_s = np.array([0, 0])
    cov_s = np.array([[1, 0], [0, 1]])

    X_s = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1)
    X_t = X_s[::-1].copy()

    a = ot.unif(n_samples)
    b = ot.unif(n_samples)

    ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t)

    Q, R, g, log = ot.gromov.lowrank_gromov_wasserstein_samples(X_sb, X_tb, ab, bb, reg=0.1, log=True)
    lazy_plan = log["lazy_plan"]
    P = lazy_plan[:]

    np.testing.assert_allclose(ab, P.sum(1), atol=1e-04)
    np.testing.assert_allclose(bb, P.sum(0), atol=1e-04)
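A closing note on the lazy plan used throughout the tests: since the coupling is stored through its factors (Q, R, g), single rows can be inspected without materializing the full plan. A small sketch (assuming row indexing on `LazyTensor` follows the `[:]` slicing behavior used above):

```python
import numpy as np
import ot

rng = np.random.RandomState(42)
X_s, X_t = rng.randn(200, 2), rng.randn(300, 3)

Q, R, g, log = ot.gromov.lowrank_gromov_wasserstein_samples(X_s, X_t, rank=8, log=True)

lazy_plan = log["lazy_plan"]
full = lazy_plan[:]      # materializes the (200, 300) coupling
row = lazy_plan[0]       # a single row, computed on the fly from Q, R, g
np.testing.assert_allclose(np.asarray(row).ravel(), full[0], atol=1e-10)
```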