diff --git a/README.md b/README.md index f64db8f56..5bc3fe409 100644 --- a/README.md +++ b/README.md @@ -40,8 +40,7 @@ POT provides the following generic OT solvers (links to examples): * [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] * Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20]. * [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41] -* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] - formulations). +* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) and [Partial Fused Gromov-Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_partial_fgw.html) (exact [29] and entropic [3] formulations). * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. * [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45] * [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] diff --git a/RELEASES.md b/RELEASES.md index 745a7de67..b24b85c9c 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -7,6 +7,7 @@ - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693) - Automatic PR labeling and release file update check (PR #704) - Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714) +- Implement projected gradient descent solvers for entropic partial FGW (PR #702) - Fix documentation in the module `ot.gaussian` (PR #718) #### Closed issues diff --git a/examples/gromov/plot_barycenter_fgw.py b/examples/gromov/plot_barycenter_fgw.py index 865c1e71a..b51b4e1ff 100644 --- a/examples/gromov/plot_barycenter_fgw.py +++ b/examples/gromov/plot_barycenter_fgw.py @@ -91,7 +91,7 @@ def build_noisy_circular_graph( g = nx.Graph() g.add_nodes_from(list(range(N))) for i in range(N): - noise = float(np.random.normal(mu, sigma, 1)) + noise = np.random.normal(mu, sigma, 1)[0] if with_noise: g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise) else: @@ -107,7 +107,7 @@ def build_noisy_circular_graph( if i == N - 1: g.add_edge(i, 1) g.add_edge(N, 0) - noise = float(np.random.normal(mu, sigma, 1)) + noise = np.random.normal(mu, sigma, 1)[0] if with_noise: g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise) else: @@ -157,7 +157,7 @@ def graph_colors(nx_graph, vmin=0, vmax=7): plt.subplot(3, 3, i + 1) g = X0[i] pos = nx.kamada_kawai_layout(g) - nx.draw( + nx.draw_networkx( g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), @@ -173,7 +173,7 @@ def graph_colors(nx_graph, vmin=0, vmax=7): # %% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph # Features distances are the euclidean distances -Cs = [shortest_path(nx.adjacency_matrix(x).todense()) for x in X0] +Cs = [shortest_path(nx.adjacency_matrix(x).toarray()) for x in X0] ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0] Ys = [ np.array([v for (k, v) in nx.get_node_attributes(x, "attr_name").items()]).reshape( @@ -199,7 +199,7 @@ def graph_colors(nx_graph, vmin=0, vmax=7): # %% pos = nx.kamada_kawai_layout(bary) -nx.draw( +nx.draw_networkx( bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False ) plt.suptitle("Barycenter", fontsize=20) diff --git a/examples/gromov/plot_partial_fgw.py b/examples/gromov/plot_partial_fgw.py new file mode 100644 index 000000000..cd6976074 --- /dev/null +++ b/examples/gromov/plot_partial_fgw.py @@ -0,0 +1,401 @@ +# -*- coding: utf-8 -*- +r""" +================================= +Plot partial FGW for subgraph matching +================================= + +This example illustrates the computation of partial (Fused) Gromov-Wasserstein +divergences for subgraph matching tasks, using the exact formulation $p(F)GW$ and +the entropically regularized one $p(F)GW_e$ [18, 29]. + +We first create a clean circular graph of 15 nodes with node features correlated with +node positions on the unit circle, and a noisy version where 5 nodes out of the +circle are added. Then knowing the proportion of clean samples in the target graph +$m=3/4$, we show how to identify them using : + - The partial GW matching and its entropic counterpart, omitting node features. + - The partial Fused GW matching and its entropic counterpart. + +[18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain +and Courty Nicolas +"Optimal Transport for structured data with application on graphs" +International Conference on Machine Learning (ICML). 2019. + +[29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal +Transport with Applications on Positive-Unlabeled Learning". NeurIPS. +""" + +# Author: Cédric Vincent-Cuaz +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 3 + +# %% load libraries +import numpy as np +import pylab as pl +import networkx as nx +import math +from scipy.sparse.csgraph import shortest_path +import matplotlib.colors as mcol +from matplotlib import cm +from ot.gromov import ( + partial_gromov_wasserstein, + entropic_partial_gromov_wasserstein, + partial_fused_gromov_wasserstein, + entropic_partial_fused_gromov_wasserstein, +) +from ot import unif, dist + +############################################################################## +# Utils for generation and visualization +# ------------- + + +def build_noisy_circular_graph(n_clean=15, n_noise=5, random_seed=0): + """Create a noisy circular graph""" + # create clean circle + np.random.seed(random_seed) + g = nx.Graph() + g.add_nodes_from(np.arange(n_clean + n_noise)) + for i in range(n_clean): + g.add_node(i, weight=math.sin(2 * i * math.pi / n_clean)) + if i == (n_clean - 1): + g.add_edge(i, 0) + else: + g.add_edge(i, i + 1) + # add nodes out of the circle as structure noise + if n_noise > 0: + noisy_nodes = np.random.choice(np.arange(n_clean), n_noise) + for i, j in enumerate(noisy_nodes): + g.add_node(i + n_clean, weight=math.sin(2 * j * math.pi / n_clean)) + g.add_edge(i + n_clean, j) + g.add_edge(i + n_clean, (j + 1) % n_clean) + return g + + +def graph_colors(nx_graph, vmin=0, vmax=7): + cnorm = mcol.Normalize(vmin=vmin, vmax=vmax) + cpick = cm.ScalarMappable(norm=cnorm, cmap="viridis") + cpick.set_array([]) + val_map = {} + for k, v in nx.get_node_attributes(nx_graph, "weight").items(): + val_map[k] = cpick.to_rgba(v) + colors = [] + for node in nx_graph.nodes(): + colors.append(val_map[node]) + return colors + + +def draw_graph( + G, + C, + nodes_color_part, + Gweights=None, + pos=None, + edge_color="black", + node_size=None, + shiftx=0, +): + if pos is None: + pos = nx.kamada_kawai_layout(G) + + if shiftx != 0: + for k, v in pos.items(): + v[0] = v[0] + shiftx + + alpha_edge = 0.7 + width_edge = 1.8 + if Gweights is None: + nx.draw_networkx_edges( + G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color + ) + else: + # We make more visible connections between activated nodes + n = len(Gweights) + edgelist_activated = [] + edgelist_deactivated = [] + for i in range(n): + for j in range(n): + if Gweights[i] * Gweights[j] * C[i, j] > 0: + edgelist_activated.append((i, j)) + elif C[i, j] > 0: + edgelist_deactivated.append((i, j)) + + nx.draw_networkx_edges( + G, + pos, + edgelist=edgelist_activated, + width=width_edge, + alpha=alpha_edge, + edge_color=edge_color, + ) + nx.draw_networkx_edges( + G, + pos, + edgelist=edgelist_deactivated, + width=width_edge, + alpha=0.1, + edge_color=edge_color, + ) + + if Gweights is None: + for node, node_color in enumerate(nodes_color_part): + nx.draw_networkx_nodes( + G, + pos, + nodelist=[node], + node_size=node_size, + alpha=1, + node_color=node_color, + ) + else: + scaled_Gweights = Gweights / (0.5 * Gweights.max()) + nodes_size = node_size * scaled_Gweights + for node, node_color in enumerate(nodes_color_part): + if nodes_size[node] == 0: + local_node_size = 0 + else: + local_node_size = max(0.1 * node_size, nodes_size[node]) + nx.draw_networkx_nodes( + G, + pos, + nodelist=[node], + node_size=local_node_size, + alpha=1, + node_color=node_color, + ) + return pos + + +def draw_transp_colored( + G1, + C1, + G2, + C2, + p1, + p2, + T, + pos1=None, + pos2=None, + shiftx=4, + switchx=False, + node_size=70, + color_features=False, +): + if color_features: + nodes_color_part1 = graph_colors(G1, vmin=-1, vmax=1) + nodes_color_part2 = graph_colors(G2, vmin=-1, vmax=1) + else: + nodes_color_part1 = C1.shape[0] * ["C0"] + nodes_color_part2 = C2.shape[0] * ["C0"] + + pos1 = draw_graph( + G1, + C1, + nodes_color_part1, + Gweights=p1, + pos=pos1, + node_size=node_size, + shiftx=0, + ) + pos2 = draw_graph( + G2, + C2, + nodes_color_part2, + Gweights=p2, + pos=pos2, + node_size=node_size, + shiftx=shiftx, + ) + T_max = T.max() + for k1, v1 in pos1.items(): + for k2, v2 in pos2.items(): + if T[k1, k2] > 0: + pl.plot( + [pos1[k1][0], pos2[k2][0]], + [pos1[k1][1], pos2[k2][1]], + "-", + lw=0.8, + alpha=max(0.05, 0.8 * T[k1, k2] / T_max), + color=nodes_color_part1[k1], + ) + return pos1, pos2 + + +############################################################################## +# Generate and visualize data +# ------------- +# We build a clean circular graph that will be matched to a noisy circular graph. + +clean_graph = build_noisy_circular_graph(n_clean=15, n_noise=0) + +noisy_graph = build_noisy_circular_graph(n_clean=15, n_noise=5) + +graphs = [clean_graph, noisy_graph] +list_pos = [] +pl.figure(figsize=(6, 3)) +for i in range(2): + pl.subplot(1, 2, i + 1) + g = graphs[i] + if i == 0: + pl.title("clean graph", fontsize=16) + else: + pl.title("noisy graph", fontsize=16) + pos = nx.kamada_kawai_layout(g) + list_pos.append(pos) + nx.draw_networkx( + g, + pos=pos, + node_color=graph_colors(g, vmin=-1, vmax=1), + with_labels=False, + node_size=100, + ) +pl.show() + +############################################################################## +# Partial (Entropic) Gromov-Wasserstein computation and visualization +# ---------------------- +# Adjacency matrices are compared using both exact and entropic partial GW +# discarding for now node features. +# Then for illustration, the node sizes are proportional to their optimized masses +# and the intensity of the link between two nodes across graphs is set proportionally +# to the corresponding transported mass. + +Cs = [nx.adjacency_matrix(G).toarray().astype(np.float64) for G in graphs] +ps = [unif(C.shape[0]) for C in Cs] + +# provide an informative initialization for better visualization +m = 3.0 / 4.0 +partial_id = np.zeros((15, 20)) +partial_id[:15, :15] = np.eye(15) / 15.0 +G0 = (np.outer(ps[0], ps[1]) + partial_id) * m / 2 + +# compute exact partial GW +T, log = partial_gromov_wasserstein( + Cs[0], Cs[1], ps[0], ps[1], m=m, G0=G0, symmetric=True, log=True +) + +# compute entropic partial GW leading to dense transport plans +Tent, logent = entropic_partial_gromov_wasserstein( + Cs[0], Cs[1], ps[0], ps[1], reg=0.01, m=m, G0=G0, symmetric=True, log=True +) + +# Plot matchings +list_T = [T, Tent] +list_dist = [ + np.round(log["partial_gw_dist"], 3), + np.round(logent["partial_gw_dist"], 3), +] +list_dist_str = ["pGW", "pGW_e"] + +pl.figure(2, figsize=(10, 3)) +pl.clf() +for i in range(2): + pl.subplot(1, 2, i + 1) + pl.axis("off") + pl.title( + r"$%s(\mathbf{C_1},\mathbf{p_1}^\star,\mathbf{C_2},\mathbf{p_2}^\star) =%s$" + % (list_dist_str[i], list_dist[i]), + fontsize=14, + ) + + p2 = list_T[i].sum(0) + + pos1, pos2 = draw_transp_colored( + clean_graph, + Cs[0], + noisy_graph, + Cs[1], + p1=None, + p2=p2, + T=list_T[i], + shiftx=3, + node_size=50, + ) + +pl.tight_layout() +pl.show() + +############################################################################## +# Partial (Entropic) Fused Gromov-Wasserstein computation and visualization +# ---------------------- +# We add now node features compared using pairwise euclidean distance +# to illustrate partial FGW computation with trade-off parameter alpha=0.5 + +Ys = [ + np.array([v for (k, v) in nx.get_node_attributes(G, "weight").items()]).reshape( + -1, 1 + ) + for G in graphs +] +M = dist(Ys[0], Ys[1]) +# provide an informative initialization for better visualization +m = 3.0 / 4.0 +partial_id = np.zeros((15, 20)) +partial_id[:15, :15] = np.eye(15) / 15.0 +G0 = (np.outer(ps[0], ps[1]) + partial_id) * m / 2 + +# compute exact partial GW +T, log = partial_fused_gromov_wasserstein( + M, + Cs[0], + Cs[1], + ps[0], + ps[1], + alpha=0.5, + m=m, + G0=G0, + symmetric=True, + log=True, +) + +# compute entropic partial GW leading to dense transport plans +Tent, logent = entropic_partial_fused_gromov_wasserstein( + M, + Cs[0], + Cs[1], + ps[0], + ps[1], + reg=0.01, + alpha=0.5, + m=m, + G0=G0, + symmetric=True, + log=True, +) + +# Plot matchings +list_T = [T, Tent] +list_dist = [ + np.round(log["partial_fgw_dist"], 3), + np.round(logent["partial_fgw_dist"], 3), +] +list_dist_str = ["pFGW", "pFGW_e"] + +pl.figure(3, figsize=(10, 3)) +pl.clf() +for i in range(2): + pl.subplot(1, 2, i + 1) + pl.axis("off") + pl.title( + r"$%s(\mathbf{C_1},\mathbf{p_1}^\star,\mathbf{C_2}, \mathbf{p_2}^\star) =%s$" + % (list_dist_str[i], list_dist[i]), + fontsize=14, + ) + + p2 = list_T[i].sum(0) + pos1, pos2 = draw_transp_colored( + clean_graph, + Cs[0], + noisy_graph, + Cs[1], + p1=None, + p2=p2, + T=list_T[i], + shiftx=3, + node_size=50, + color_features=True, + ) + +pl.tight_layout() +pl.show() diff --git a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py index 5ccc197d6..23a5f96a2 100755 --- a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py +++ b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py @@ -5,7 +5,10 @@ ================================================== This example is designed to show how to use the Partial (Gromov-)Wasserstein -distance computation in POT. +distance computation in POT [29]. + +[29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal +Transport with Applications on Positive-Unlabeled Learning". NeurIPS. """ # Author: Laetitia Chapel diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py index b7520e099..4c470b8de 100644 --- a/ot/gromov/__init__.py +++ b/ot/gromov/__init__.py @@ -107,6 +107,8 @@ solve_partial_gromov_linesearch, entropic_partial_gromov_wasserstein, entropic_partial_gromov_wasserstein2, + entropic_partial_fused_gromov_wasserstein, + entropic_partial_fused_gromov_wasserstein2, ) @@ -180,4 +182,6 @@ "solve_partial_gromov_linesearch", "entropic_partial_gromov_wasserstein", "entropic_partial_gromov_wasserstein2", + "entropic_partial_fused_gromov_wasserstein", + "entropic_partial_fused_gromov_wasserstein2", ] diff --git a/ot/gromov/_partial.py b/ot/gromov/_partial.py index b08f60174..d15075c02 100644 --- a/ot/gromov/_partial.py +++ b/ot/gromov/_partial.py @@ -45,7 +45,7 @@ def partial_gromov_wasserstein( .. math:: \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) T_{i,j} T_{k,l} s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} @@ -332,7 +332,7 @@ def partial_gromov_wasserstein2( .. math:: \mathbf{PGW} = \mathop{\min}_\mathbf{T} \quad \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) T_{i,j} T_{k,l} s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} @@ -524,7 +524,7 @@ def partial_fused_gromov_wasserstein( .. math:: \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) T_{i,j} T_{k,l} s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} @@ -812,7 +812,7 @@ def partial_fused_gromov_wasserstein2( .. math:: \mathbf{PFGW}_{\alpha} = \mathop{\min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) T_{i,j} T_{k,l} s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} @@ -1088,18 +1088,18 @@ def entropic_partial_gromov_wasserstein( The function solves the following optimization problem: .. math:: - \gamma = \mathop{\arg \min}_{\gamma} \quad \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot - \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) + \mathbf{T} = \mathop{\arg \min}_{\mathbf{T}} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) + T_{i,j} T_{k,l} + \mathrm{reg} \Omega(\mathbf{T}) .. math:: - s.t. \ \gamma &\geq 0 + s.t. \ \mathbf{T} &\geq 0 - \gamma \mathbf{1} &\leq \mathbf{a} + \mathbf{T} \mathbf{1} &\leq \mathbf{a} - \gamma^T \mathbf{1} &\leq \mathbf{b} + \mathbf{T}^T \mathbf{1} &\leq \mathbf{b} - \mathbf{1}^T \gamma^T \mathbf{1} = m + \mathbf{1}^T \mathbf{T}^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : @@ -1109,7 +1109,7 @@ def entropic_partial_gromov_wasserstein( - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights - `L`: quadratic loss function - :math:`\Omega` is the entropic regularization term, - :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + :math:`\Omega(\mathbf{T})=\sum_{i,j} T_{i,j}\log(T_{i,j})` - `m` is the amount of mass to be transported The formulation of the GW problem has been proposed in @@ -1173,7 +1173,7 @@ def entropic_partial_gromov_wasserstein( Returns ------- - :math:`gamma` : ndarray, shape (dim_a, dim_b) + T : ndarray, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary returned only if `log` is `True` @@ -1327,18 +1327,18 @@ def entropic_partial_gromov_wasserstein2( The function solves the following optimization problem: .. math:: - PGW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, - \mathbf{C_2}_{j,l})\cdot - \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) + PGW = \min_{\mathbf{T}} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, + \mathbf{C_2}_{j,l}) + T_{i,j}T_{k,l} + \mathrm{reg} \Omega(\mathbf{T}) .. math:: - s.t. \ \gamma &\geq 0 + s.t. \ \mathbf{T} &\geq 0 - \gamma \mathbf{1} &\leq \mathbf{a} + \mathbf{T} \mathbf{1} &\leq \mathbf{a} - \gamma^T \mathbf{1} &\leq \mathbf{b} + \mathbf{T}^T \mathbf{1} &\leq \mathbf{b} - \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} + \mathbf{1}^T \mathbf{T}^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : @@ -1347,7 +1347,7 @@ def entropic_partial_gromov_wasserstein2( - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights - `L`: Loss function to account for the misfit between the similarity matrices. - :math:`\Omega` is the entropic regularization term, - :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + :math:`\Omega(\mathbf{T})=\sum_{i,j} T_{i,j}\log(T_{i,j})` - `m` is the amount of mass to be transported The formulation of the GW problem has been proposed in @@ -1433,3 +1433,380 @@ def entropic_partial_gromov_wasserstein2( return log_gw["partial_gw_dist"], log_gw else: return log_gw["partial_gw_dist"] + + +def entropic_partial_fused_gromov_wasserstein( + M, + C1, + C2, + p=None, + q=None, + reg=1.0, + m=None, + loss_fun="square_loss", + alpha=0.5, + G0=None, + numItermax=1000, + tol=1e-7, + symmetric=None, + log=False, + verbose=False, +): + r""" + Returns the entropic partial Fused Gromov-Wasserstein transport between + :math:`(\mathbf{C_1}, \mathbf{F_1}, \mathbf{p})` and + :math:`(\mathbf{C_2}, \mathbf{F_2}, \mathbf{q})`, with pairwise + distance matrix :math:`\mathbf{M}` between node feature matrices. + + The function solves the following optimization problem: + + .. math:: + \mathbf{T} = \mathop{\arg \min}_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) + T_{i,j} T_{k,l} + \mathrm{reg} \Omega(\mathbf{T}) + + .. math:: + s.t. \ \mathbf{T} &\geq 0 + + \mathbf{T} \mathbf{1} &\leq \mathbf{a} + + \mathbf{T}^T \mathbf{1} &\leq \mathbf{b} + + \mathbf{1}^T \mathbf{T}^T \mathbf{1} = m + &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} + + where : + + - :math:`\mathbf{M}`: metric cost matrix between features across domains + - :math:`\mathbf{C_1}` is the metric cost matrix in the source space + - :math:`\mathbf{C_2}` is the metric cost matrix in the target space + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights + - `L`: quadratic loss function + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega(\mathbf{T})=\sum_{i,j} T_{i,j}\log(T_{i,j})` + - `m` is the amount of mass to be transported + + The formulation of the FGW problem has been proposed in + :ref:`[24] ` and the + partial GW in :ref:`[29] ` + + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional + Distribution in the target space. + If let to its default value None, uniform distribution is taken. + reg: float, optional. Default is 1. + entropic regularization parameter + m : float, optional + Amount of mass to be transported (default: + :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + loss_fun : str, optional + Loss function used for the solver either 'square_loss' or 'kl_loss'. + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + G0 : array-like, shape (ns, nt), optional + Initialization of the transportation matrix + numItermax : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + log : bool, optional + return log if True + verbose : bool, optional + Print information along iterations + + Returns + ------- + T : ndarray, shape (dim_a, dim_b) + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + + + .. _references-entropic-partial-fused-gromov-wasserstein: + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas "Optimal Transport for structured data with + application on graphs", International Conference on Machine Learning + (ICML). 2019. + + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. + + See Also + -------- + ot.gromov.partial_fused_gromov_wasserstein: exact Partial Fused Gromov-Wasserstein + """ + + arr = [M, C1, C2, G0] + if p is not None: + p = list_to_array(p) + arr.append(p) + if q is not None: + q = list_to_array(q) + arr.append(q) + + nx = get_backend(*arr) + + if p is None: + p = nx.ones(C1.shape[0], type_as=C1) / C1.shape[0] + if q is None: + q = nx.ones(C2.shape[0], type_as=C2) / C2.shape[0] + + if m is None: + m = min(nx.sum(p), nx.sum(q)) + elif m < 0: + raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + elif m > min(nx.sum(p), nx.sum(q)): + raise ValueError( + "Problem infeasible. Parameter m should lower or" + " equal than min(|a|_1, |b|_1)." + ) + + if G0 is None: + G0 = ( + nx.outer(p, q) * m / (nx.sum(p) * nx.sum(q)) + ) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. + + else: + # Check marginals of G0 + assert nx.any(nx.sum(G0, 1) <= p) + assert nx.any(nx.sum(G0, 0) <= q) + + if symmetric is None: + symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose( + C2, C2.T, atol=1e-10 + ) + + # Setup gradient computation + fC1, fC2, hC1, hC2 = _transform_matrix(C1, C2, loss_fun, nx) + fC2t = fC2.T + if not symmetric: + fC1t, hC1t, hC2t = fC1.T, hC1.T, hC2.T + + ones_p = nx.ones(p.shape[0], type_as=p) + ones_q = nx.ones(q.shape[0], type_as=q) + + def f(G): + pG = nx.sum(G, 1) + qG = nx.sum(G, 0) + constC1 = nx.outer(nx.dot(fC1, pG), ones_q) + constC2 = nx.outer(ones_p, nx.dot(qG, fC2t)) + return alpha * gwloss(constC1 + constC2, hC1, hC2, G, nx) + ( + 1 - alpha + ) * nx.sum(G * M) + + if symmetric: + + def df(G): + pG = nx.sum(G, 1) + qG = nx.sum(G, 0) + constC1 = nx.outer(nx.dot(fC1, pG), ones_q) + constC2 = nx.outer(ones_p, nx.dot(qG, fC2t)) + return alpha * gwggrad(constC1 + constC2, hC1, hC2, G, nx) + ( + 1 - alpha + ) * nx.sum(G * M) + else: + + def df(G): + pG = nx.sum(G, 1) + qG = nx.sum(G, 0) + constC1 = nx.outer(nx.dot(fC1, pG), ones_q) + constC2 = nx.outer(ones_p, nx.dot(qG, fC2t)) + constC1t = nx.outer(nx.dot(fC1t, pG), ones_q) + constC2t = nx.outer(ones_p, nx.dot(qG, fC2)) + + return 0.5 * alpha * ( + gwggrad(constC1 + constC2, hC1, hC2, G, nx) + + gwggrad(constC1t + constC2t, hC1t, hC2t, G, nx) + ) + (1 - alpha) * nx.sum(G * M) + + cpt = 0 + err = 1 + + loge = {"err": []} + + while err > tol and cpt < numItermax: + Gprev = G0 + M_entr = df(G0) + G0 = entropic_partial_wasserstein(p, q, M_entr, reg, m) + if cpt % 10 == 0: # to speed up the computations + err = np.linalg.norm(G0 - Gprev) + if log: + loge["err"].append(err) + if verbose: + if cpt % 200 == 0: + print( + "{:5s}|{:12s}|{:12s}".format("It.", "Err", "Loss") + + "\n" + + "-" * 31 + ) + print("{:5d}|{:8e}|{:8e}".format(cpt, err, f(G0))) + + cpt += 1 + + if log: + loge["partial_fgw_dist"] = f(G0) + return G0, loge + else: + return G0 + + +def entropic_partial_fused_gromov_wasserstein2( + M, + C1, + C2, + p=None, + q=None, + reg=1.0, + m=None, + loss_fun="square_loss", + alpha=0.5, + G0=None, + numItermax=1000, + tol=1e-7, + symmetric=None, + log=False, + verbose=False, +): + r""" + Returns the entropic partial Fused Gromov-Wasserstein discrepancy between + :math:`(\mathbf{C_1}, \mathbf{F_1}, \mathbf{p})` and + :math:`(\mathbf{C_2}, \mathbf{F_2}, \mathbf{q})`, with pairwise + distance matrix :math:`\mathbf{M}` between node feature matrices. + + The function solves the following optimization problem: + + .. math:: + PGW = \min_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) T_{i,j} T_{k,l} + + \mathrm{reg} \cdot\Omega(\mathbf{T}) + + .. math:: + s.t. \ \mathbf{T} &\geq 0 + + \mathbf{T} \mathbf{1} &\leq \mathbf{a} + + \mathbf{T}^T \mathbf{1} &\leq \mathbf{b} + + \mathbf{1}^T \mathbf{T}^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} + + where : + + - :math:`\mathbf{M}`: metric cost matrix between features across domains + - :math:`\mathbf{C_1}` is the metric cost matrix in the source space + - :math:`\mathbf{C_2}` is the metric cost matrix in the target space + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights + - `L`: Loss function to account for the misfit between the similarity matrices. + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega(\mathbf{T})=\sum_{i,j} T_{i,j}\log(T_{i,j})` + - `m` is the amount of mass to be transported + + The formulation of the FGW problem has been proposed in + :ref:`[24] ` and the + partial GW in :ref:`[29] ` + + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional + Distribution in the target space. + If let to its default value None, uniform distribution is taken. + reg: float + entropic regularization parameter + m : float, optional + Amount of mass to be transported (default: + :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + loss_fun : str, optional + Loss function used for the solver either 'square_loss' or 'kl_loss'. + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + G0 : ndarray, shape (ns, nt), optional + Initialization of the transportation matrix + numItermax : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + log : bool, optional + return log if True + verbose : bool, optional + Print information along iterations + + + Returns + ------- + partial_fgw_dist: float + Partial Entropic Fused Gromov-Wasserstein discrepancy + log : dict + log dictionary returned only if `log` is `True` + + .. _references-entropic-partial-fused-gromov-wasserstein2: + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas "Optimal Transport for structured data with + application on graphs", International Conference on Machine Learning + (ICML). 2019. + + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. + """ + nx = get_backend(M, C1, C2) + + T, log_pfgw = entropic_partial_fused_gromov_wasserstein( + M, + C1, + C2, + p, + q, + reg, + m, + loss_fun, + alpha, + G0, + numItermax, + tol, + symmetric, + True, + verbose, + ) + + log_pfgw["T"] = T + + # setup for ot.solve_gromov + lin_term = nx.sum(T * M) + log_pfgw["quad_loss"] = log_pfgw["partial_fgw_dist"] - (1 - alpha) * lin_term + log_pfgw["lin_loss"] = lin_term * (1 - alpha) + + if log: + return log_pfgw["partial_fgw_dist"], log_pfgw + else: + return log_pfgw["partial_fgw_dist"] diff --git a/ot/solvers.py b/ot/solvers.py index ff513ed0d..a5bbf0e94 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -25,12 +25,16 @@ partial_gromov_wasserstein2, partial_fused_gromov_wasserstein2, entropic_partial_gromov_wasserstein2, + entropic_partial_fused_gromov_wasserstein2, ) from .gaussian import empirical_bures_wasserstein_distance from .factored import factored_optimal_transport from .lowrank import lowrank_sinkhorn from .optim import cg +import warnings + + lst_method_lazy = [ "1d", "gaussian", @@ -657,7 +661,8 @@ def solve_gromov( ``alpha=0.5`` for Fused Gromov-Wasserstein problem (``M!=None``) unbalanced : float, optional Unbalanced penalization weight :math:`\lambda_u`, by default None - (balanced OT), Not implemented yet + (balanced OT). Not implemented yet for "KL" unbalanced penalization + function :math:`U`. Corresponds to the total transport mass for partial OT. unbalanced_type : str, optional Type of unbalanced penalization function :math:`U` either "KL", "semirelaxed", "partial", by default "KL" but note that it is not implemented yet. @@ -862,8 +867,15 @@ def solve_gromov( if reg is None or reg == 0: # exact OT if unbalanced is None and unbalanced_type.lower() not in [ - "semirelaxed" + "semirelaxed", ]: # Exact balanced OT + if unbalanced_type.lower() in ["partial"]: + warnings.warn( + "Exact balanced OT is computed as `unbalanced=None` even though " + f"unbalanced_type = {unbalanced_type}.", + stacklevel=2, + ) + if M is None or alpha == 1: # Gromov-Wasserstein problem # default values for solver if max_iter is None: @@ -999,9 +1011,11 @@ def solve_gromov( # potentials = (log['u'], log['v']) TODO elif unbalanced_type.lower() in ["partial"]: # Partial OT - if M is None: # Partial Gromov-Wasserstein problem + if M is None or alpha == 1.0: # Partial Gromov-Wasserstein problem if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise (ValueError("Partial GW mass given in reg is too large")) + raise ( + ValueError("Partial GW mass given in `unbalanced` is too large") + ) # default values for solver if max_iter is None: @@ -1030,8 +1044,9 @@ def solve_gromov( else: # partial FGW if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise (ValueError("Partial FGW mass given in reg is too large")) - + raise ( + ValueError("Partial GW mass given in `unbalanced` is too large") + ) # default values for solver if max_iter is None: max_iter = 1000 @@ -1072,8 +1087,15 @@ def solve_gromov( else: # regularized OT if unbalanced is None and unbalanced_type.lower() not in [ - "semirelaxed" + "semirelaxed", ]: # Balanced regularized OT + if unbalanced_type.lower() in ["partial"]: + warnings.warn( + "Exact balanced OT is computed as `unbalanced=None` even though " + f"unbalanced_type = {unbalanced_type}.", + stacklevel=2, + ) + if reg_type.lower() in ["entropy"] and ( M is None or alpha == 1 ): # Entropic Gromov-Wasserstein problem @@ -1229,9 +1251,11 @@ def solve_gromov( value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) elif unbalanced_type.lower() in ["partial"]: # Partial OT - if M is None: # Partial Gromov-Wasserstein problem + if M is None or alpha == 1.0: # Partial Gromov-Wasserstein problem if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise (ValueError("Partial GW mass given in reg is too large")) + raise ( + ValueError("Partial GW mass given in `unbalanced` is too large") + ) # default values for solver if max_iter is None: @@ -1239,7 +1263,7 @@ def solve_gromov( if tol is None: tol = 1e-7 - value_quad, log = entropic_partial_gromov_wasserstein2( + value_noreg, log = entropic_partial_gromov_wasserstein2( Ca, Cb, a, @@ -1255,12 +1279,45 @@ def solve_gromov( verbose=verbose, ) - value_quad = value + value_quad = value_noreg plan = log["T"] # potentials = (log['u'], log['v']) TODO - + value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) else: # partial FGW - raise (NotImplementedError("Partial entropic FGW not implemented yet")) + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise ( + ValueError("Partial GW mass given in `unbalanced` is too large") + ) + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-7 + + value_noreg, log = entropic_partial_fused_gromov_wasserstein2( + M, + Ca, + Cb, + a, + b, + reg=reg, + loss_fun=loss_fun, + alpha=alpha, + m=unbalanced, + log=True, + numItermax=max_iter, + G0=plan_init, + tol=tol, + symmetric=symmetric, + verbose=verbose, + ) + + value_linear = log["lin_loss"] + value_quad = log["quad_loss"] + plan = log["T"] + # potentials = (log['u'], log['v']) TODO + value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) else: # unbalanced AND regularized OT raise ( diff --git a/test/gromov/test_partial.py b/test/gromov/test_partial.py index 8b6205683..3b133242c 100644 --- a/test/gromov/test_partial.py +++ b/test/gromov/test_partial.py @@ -49,6 +49,16 @@ def test_raise_errors(): with pytest.raises(ValueError): ot.gromov.partial_fused_gromov_wasserstein(M, M, M, p, q, m=-1, log=True) + with pytest.raises(ValueError): + ot.gromov.entropic_partial_fused_gromov_wasserstein( + M, M, M, p, q, m=2, log=True + ) + + with pytest.raises(ValueError): + ot.gromov.entropic_partial_fused_gromov_wasserstein( + M, M, M, p, q, m=-1, log=True + ) + def test_partial_gromov_wasserstein(nx): rng = np.random.RandomState(42) @@ -585,3 +595,156 @@ def test_entropic_partial_gromov_wasserstein(nx): C1b, C2b, p=pb, q=None, reg=1e4, m=m, loss_fun=loss_fun, log=False ) np.testing.assert_allclose(w0, w0_val, rtol=1e-8) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_entropic_partial_fused_gromov_wasserstein(nx): + rng = np.random.RandomState(42) + n_samples = 20 # nb samples + n_noise = 10 # nb of samples (noise) + + p = ot.unif(n_samples + n_noise) + psub = ot.unif(n_samples - 5 + n_noise) + q = ot.unif(n_samples + n_noise) + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + mu_t = np.array([0, 0, 0]) + cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + # clean samples + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=rng) + P = sp.linalg.sqrtm(cov_t) + xt = rng.randn(n_samples, 3).dot(P) + mu_t + # add noise + xs = np.concatenate((xs, ((rng.rand(n_noise, 2) + 1) * 4)), axis=0) + xt = np.concatenate((xt, ((rng.rand(n_noise, 3) + 1) * 10)), axis=0) + xt2 = xs[::-1].copy() + + C1 = ot.dist(xs, xs) + F1 = xs + + C1sub = ot.dist(xs[5:], xs[5:]) + F1sub = xs[5:] + + C2 = ot.dist(xt, xt) + F2 = xs + + C3 = ot.dist(xt2, xt2) + F3 = xt2 + + M11sub = ot.dist(F1, F1sub) + M12 = ot.dist(F1, F2) + M13 = ot.dist(F1, F3) + + m = 2.0 / 3.0 + + M11subb, M12b, M13b, C1b, C1subb, C2b, C3b, pb, psubb, qb = nx.from_numpy( + M11sub, M12, M13, C1, C1sub, C2, C3, p, psub, q + ) + + G0 = ( + np.outer(p, q) * m / (np.sum(p) * np.sum(q)) + ) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. + G0b = nx.from_numpy(G0) + + # check consistency across backends and stability w.r.t loss/marginals/sym + list_sym = [True, None] + for i, loss_fun in enumerate(["square_loss", "kl_loss"]): + res, log = ot.gromov.entropic_partial_fused_gromov_wasserstein( + M13, + C1, + C3, + p=p, + q=None, + reg=1e4, + m=m, + loss_fun=loss_fun, + G0=None, + log=True, + symmetric=list_sym[i], + verbose=True, + ) + + resb, logb = ot.gromov.entropic_partial_fused_gromov_wasserstein( + M13b, + C1b, + C3b, + p=None, + q=qb, + reg=1e4, + m=m, + loss_fun=loss_fun, + G0=G0b, + log=True, + symmetric=False, + verbose=True, + ) + + resb_ = nx.to_numpy(resb) + try: # some instability can occur with kl. to investigate further. + np.testing.assert_allclose(res, resb_, rtol=1e-4) + except AssertionError: + pass + + assert np.all(res.sum(1) <= p) # cf convergence wasserstein + assert np.all(res.sum(0) <= q) # cf convergence wasserstein + + # tests with m is None + res = ot.gromov.entropic_partial_fused_gromov_wasserstein( + M13, + C1, + C3, + p=p, + q=None, + reg=1e4, + G0=None, + log=False, + symmetric=list_sym[i], + verbose=True, + ) + + resb = ot.gromov.entropic_partial_fused_gromov_wasserstein( + M13b, + C1b, + C3b, + p=None, + q=qb, + reg=1e4, + G0=None, + log=False, + symmetric=False, + verbose=True, + ) + + resb_ = nx.to_numpy(resb) + np.testing.assert_allclose(res, resb_, rtol=1e-4) + np.testing.assert_allclose(np.sum(res), 1.0, rtol=1e-4) + + # tests with different number of samples across spaces + m = 0.5 + res, log = ot.gromov.entropic_partial_fused_gromov_wasserstein( + M11sub, C1, C1sub, p=p, q=psub, reg=1e4, m=m, log=True + ) + + resb, logb = ot.gromov.entropic_partial_fused_gromov_wasserstein( + M11subb, C1b, C1subb, p=pb, q=psubb, reg=1e4, m=m, log=True + ) + + resb_ = nx.to_numpy(resb) + np.testing.assert_allclose(res, resb_, rtol=1e-4) + assert np.all(res.sum(1) <= p) # cf convergence wasserstein + assert np.all(res.sum(0) <= psub) # cf convergence wasserstein + np.testing.assert_allclose(np.sum(res), m, rtol=1e-4) + + # tests for pGW2 + for loss_fun in ["square_loss", "kl_loss"]: + w0, log0 = ot.gromov.entropic_partial_fused_gromov_wasserstein2( + M12, C1, C2, p=None, q=q, reg=1e4, m=m, loss_fun=loss_fun, log=True + ) + w0_val = ot.gromov.entropic_partial_fused_gromov_wasserstein2( + M12b, C1b, C2b, p=pb, q=None, reg=1e4, m=m, loss_fun=loss_fun, log=False + ) + np.testing.assert_allclose(w0, w0_val, rtol=1e-8) diff --git a/test/test_solvers.py b/test/test_solvers.py index 85852aca6..a0c1d7c43 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -518,6 +518,10 @@ def test_solve_gromov_not_implemented(nx): ot.solve_gromov(Ca, Cb, unbalanced_type="partial", unbalanced=1.5) with pytest.raises(ValueError): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type="partial", unbalanced=1.5) + with pytest.raises(ValueError): + ot.solve_gromov(Ca, Cb, M, unbalanced_type="partial", unbalanced=1.5) + with pytest.raises(ValueError): + ot.solve_gromov(Ca, Cb, M, reg=1, unbalanced_type="partial", unbalanced=1.5) def test_solve_sample(nx):