Skip to content

Commit 2a74966

Browse files
committed
Unify docstring for other functions as well
1 parent db76ee1 commit 2a74966

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

ot/gromov/_bregman.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def entropic_gromov_wasserstein(
8181
q : array-like, shape (nt,), optional
8282
Distribution in the target space.
8383
If let to its default value None, uniform distribution is taken.
84-
loss_fun : string, optional (default='square_loss')
84+
loss_fun : string, optional (default='square_loss')
8585
Loss function used for the solver either 'square_loss' or 'kl_loss'
8686
epsilon : float, optional
8787
Regularization term >0
@@ -92,7 +92,7 @@ def entropic_gromov_wasserstein(
9292
G0: array-like, shape (ns,nt), optional
9393
If None the initial transport plan of the solver is pq^T.
9494
Otherwise G0 will be used as initial transport of the solver. G0 is not
95-
required to satisfy marginal constraints but we strongly recommand it
95+
required to satisfy marginal constraints but we strongly recommend it
9696
to correctly estimate the GW distance.
9797
max_iter : int, optional
9898
Max number of iterations
@@ -283,7 +283,7 @@ def entropic_gromov_wasserstein2(
283283
q : array-like, shape (nt,), optional
284284
Distribution in the target space.
285285
If let to its default value None, uniform distribution is taken.
286-
loss_fun : string, optional
286+
loss_fun : string, optional (default='square_loss')
287287
Loss function used for the solver either 'square_loss' or 'kl_loss'
288288
epsilon : float, optional
289289
Regularization term >0
@@ -376,7 +376,7 @@ def entropic_gromov_barycenters(
376376
lambdas : list of float, optional
377377
List of the `S` spaces' weights.
378378
If let to its default value None, uniform weights are taken.
379-
loss_fun : string, optional (default='square_loss')
379+
loss_fun : string, optional (default='square_loss')
380380
Loss function used for the solver either 'square_loss' or 'kl_loss'
381381
epsilon : float, optional
382382
Regularization term >0
@@ -555,21 +555,21 @@ def entropic_fused_gromov_wasserstein(
555555
q : array-like, shape (nt,), optional
556556
Distribution in the target space.
557557
If let to its default value None, uniform distribution is taken.
558-
loss_fun : string, optional
558+
loss_fun : string, optional (default='square_loss')
559559
Loss function used for the solver either 'square_loss' or 'kl_loss'
560560
epsilon : float, optional
561561
Regularization term >0
562562
symmetric : bool, optional
563563
Either C1 and C2 are to be assumed symmetric or not.
564564
If let to its default None value, a symmetry test will be conducted.
565-
Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric).
565+
Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
566566
alpha : float, optional
567567
Trade-off parameter (0 < alpha < 1)
568568
G0: array-like, shape (ns,nt), optional
569569
If None the initial transport plan of the solver is pq^T.
570570
Otherwise G0 will be used as initial transport of the solver. G0 is not
571-
required to satisfy marginal constraints but we strongly recommand it
572-
to correcly estimate the GW distance.
571+
required to satisfy marginal constraints but we strongly recommend it
572+
to correctly estimate the GW distance.
573573
max_iter : int, optional
574574
Max number of iterations
575575
tol : float, optional
@@ -616,6 +616,9 @@ def entropic_fused_gromov_wasserstein(
616616
if solver not in ['PGD', 'PPA']:
617617
raise ValueError("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver)
618618

619+
if loss_fun not in ('square_loss', 'kl_loss'):
620+
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
621+
619622
M, C1, C2 = list_to_array(M, C1, C2)
620623
arr = [M, C1, C2]
621624
if p is not None:
@@ -767,7 +770,7 @@ def entropic_fused_gromov_wasserstein2(
767770
q : array-like, shape (nt,), optional
768771
Distribution in the target space.
769772
If let to its default value None, uniform distribution is taken.
770-
loss_fun : string, optional
773+
loss_fun : string, optional (default='square_loss')
771774
Loss function used for the solver either 'square_loss' or 'kl_loss'
772775
epsilon : float, optional
773776
Regularization term >0
@@ -780,8 +783,8 @@ def entropic_fused_gromov_wasserstein2(
780783
G0: array-like, shape (ns,nt), optional
781784
If None the initial transport plan of the solver is pq^T.
782785
Otherwise G0 will be used as initial transport of the solver. G0 is not
783-
required to satisfy marginal constraints but we strongly recommand it
784-
to correcly estimate the GW distance.
786+
required to satisfy marginal constraints but we strongly recommend it
787+
to correctly estimate the GW distance.
785788
max_iter : int, optional
786789
Max number of iterations
787790
tol : float, optional
@@ -862,8 +865,8 @@ def entropic_fused_gromov_barycenters(
862865
lambdas : list of float, optional
863866
List of the `S` spaces' weights.
864867
If let to its default value None, uniform weights are taken.
865-
loss_fun : callable, optional
866-
tensor-matrix multiplication function based on specific loss function
868+
loss_fun : string, optional (default='square_loss')
869+
Loss function used for the solver either 'square_loss' or 'kl_loss'
867870
epsilon : float, optional
868871
Regularization term >0
869872
symmetric : bool, optional.
@@ -912,6 +915,9 @@ def entropic_fused_gromov_barycenters(
912915
"Optimal Transport for structured data with application on graphs"
913916
International Conference on Machine Learning (ICML). 2019.
914917
"""
918+
if loss_fun not in ('square_loss', 'kl_loss'):
919+
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
920+
915921
Cs = list_to_array(*Cs)
916922
Ys = list_to_array(*Ys)
917923
arr = [*Cs, *Ys]
@@ -982,7 +988,6 @@ def entropic_fused_gromov_barycenters(
982988

983989
if loss_fun == 'square_loss':
984990
C = update_square_loss(p, lambdas, T, Cs)
985-
986991
elif loss_fun == 'kl_loss':
987992
C = update_kl_loss(p, lambdas, T, Cs)
988993

0 commit comments

Comments
 (0)