@@ -43,11 +43,11 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric
43
43
44
44
Where :
45
45
46
- - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
47
- - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
48
- - :math:`\mathbf{p}`: distribution in the source space
49
- - :math:`\mathbf{q}`: distribution in the target space
50
- - `L`: loss function to account for the misfit between the similarity matrices
46
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space.
47
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space.
48
+ - :math:`\mathbf{p}`: Distribution in the source space.
49
+ - :math:`\mathbf{q}`: Distribution in the target space.
50
+ - `L`: Loss function to account for the misfit between the similarity matrices.
51
51
52
52
.. note:: This function is backend-compatible and will work on arrays
53
53
from all compatible backends. But the algorithm uses the C++ CPU backend
@@ -62,39 +62,39 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric
62
62
Parameters
63
63
----------
64
64
C1 : array-like, shape (ns, ns)
65
- Metric cost matrix in the source space
65
+ Metric cost matrix in the source space.
66
66
C2 : array-like, shape (nt, nt)
67
- Metric cost matrix in the target space
67
+ Metric cost matrix in the target space.
68
68
p : array-like, shape (ns,), optional
69
69
Distribution in the source space.
70
70
If let to its default value None, uniform distribution is taken.
71
71
q : array-like, shape (nt,), optional
72
72
Distribution in the target space.
73
73
If let to its default value None, uniform distribution is taken.
74
74
loss_fun : str, optional
75
- loss function used for the solver either 'square_loss' or 'kl_loss'
75
+ Loss function used for the solver either 'square_loss' or 'kl_loss'.
76
76
symmetric : bool, optional
77
77
Either C1 and C2 are to be assumed symmetric or not.
78
78
If let to its default None value, a symmetry test will be conducted.
79
79
Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
80
80
verbose : bool, optional
81
- Print information along iterations
81
+ Print information along iterations.
82
82
log : bool, optional
83
- record log if True
83
+ Record log if True.
84
84
armijo : bool, optional
85
- If True the step of the line-search is found via an armijo research . Else closed form is used.
86
- If there are convergence issues use False.
85
+ If True, the step of the line-search is found via an armijo search . Else closed form is used.
86
+ If there are convergence issues, use False.
87
87
G0: array-like, shape (ns,nt), optional
88
- If None the initial transport plan of the solver is pq^T.
88
+ If None, the initial transport plan of the solver is pq^T.
89
89
Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
90
90
max_iter : int, optional
91
- Max number of iterations
91
+ Max number of iterations.
92
92
tol_rel : float, optional
93
- Stop threshold on relative error (>0)
93
+ Stop threshold on relative error (>0).
94
94
tol_abs : float, optional
95
- Stop threshold on absolute error (>0)
95
+ Stop threshold on absolute error (>0).
96
96
**kwargs : dict
97
- parameters can be directly passed to the ot.optim.cg solver
97
+ Parameters can be directly passed to the ot.optim.cg solver.
98
98
99
99
Returns
100
100
-------
@@ -175,7 +175,7 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
175
175
176
176
if not nx .is_floating_point (C10 ):
177
177
warnings .warn (
178
- "Input structure matrix consists of integer . The transport plan will be "
178
+ "Input structure matrix consists of integers . The transport plan will be "
179
179
"casted accordingly, possibly resulting in a loss of precision. "
180
180
"If this behaviour is unwanted, please make sure your input "
181
181
"structure matrix consists of floating point elements." ,
0 commit comments