@@ -459,6 +459,12 @@ def test_entropic_proximal_gromov(nx):
459
459
460
460
C1b , C2b , pb , qb , G0b = nx .from_numpy (C1 , C2 , p , q , G0 )
461
461
462
+ with pytest .raises (ValueError ):
463
+ loss_fun = 'weird_loss_fun'
464
+ G , log = ot .gromov .entropic_gromov_wasserstein (
465
+ C1 , C2 , None , q , loss_fun , symmetric = None , G0 = G0 ,
466
+ epsilon = 1e-1 , max_iter = 50 , solver = 'PPA' , verbose = True , log = True , numItermax = 1 )
467
+
462
468
G , log = ot .gromov .entropic_gromov_wasserstein (
463
469
C1 , C2 , None , q , 'square_loss' , symmetric = None , G0 = G0 ,
464
470
epsilon = 1e-1 , max_iter = 50 , solver = 'PPA' , verbose = True , log = True , numItermax = 1 )
@@ -606,6 +612,12 @@ def test_entropic_fgw(nx):
606
612
607
613
Mb , C1b , C2b , pb , qb , G0b = nx .from_numpy (M , C1 , C2 , p , q , G0 )
608
614
615
+ with pytest .raises (ValueError ):
616
+ loss_fun = 'weird_loss_fun'
617
+ G , log = ot .gromov .entropic_fused_gromov_wasserstein (
618
+ M , C1 , C2 , None , None , loss_fun , symmetric = None , G0 = G0 ,
619
+ epsilon = 1e-1 , max_iter = 10 , verbose = True , log = True )
620
+
609
621
G , log = ot .gromov .entropic_fused_gromov_wasserstein (
610
622
M , C1 , C2 , None , None , 'square_loss' , symmetric = None , G0 = G0 ,
611
623
epsilon = 1e-1 , max_iter = 10 , verbose = True , log = True )
@@ -812,20 +824,28 @@ def test_entropic_fgw_barycenter(nx):
812
824
C2 = ot .dist (Xt )
813
825
p1 = ot .unif (ns )
814
826
p2 = ot .unif (nt )
815
- n_samples = 2
827
+ n_samples = 3
816
828
p = ot .unif (n_samples )
817
829
818
830
ysb , ytb , C1b , C2b , p1b , p2b , pb = nx .from_numpy (ys , yt , C1 , C2 , p1 , p2 , p )
819
831
832
+ with pytest .raises (ValueError ):
833
+ loss_fun = 'weird_loss_fun'
834
+ X , C , log = ot .gromov .entropic_fused_gromov_barycenters (
835
+ n_samples , [ys , yt ], [C1 , C2 ], None , p , [.5 , .5 ], loss_fun , 0.1 ,
836
+ max_iter = 10 , tol = 1e-3 , verbose = True , warmstartT = True , random_state = 42 ,
837
+ solver = 'PPA' , numItermax = 10 , log = True
838
+ )
839
+
820
840
X , C , log = ot .gromov .entropic_fused_gromov_barycenters (
821
841
n_samples , [ys , yt ], [C1 , C2 ], None , p , [.5 , .5 ], 'square_loss' , 0.1 ,
822
842
max_iter = 10 , tol = 1e-3 , verbose = True , warmstartT = True , random_state = 42 ,
823
- solver = 'PPA' , numItermax = 1 , log = True
843
+ solver = 'PPA' , numItermax = 10 , log = True
824
844
)
825
845
Xb , Cb = ot .gromov .entropic_fused_gromov_barycenters (
826
846
n_samples , [ysb , ytb ], [C1b , C2b ], [p1b , p2b ], None , [.5 , .5 ], 'square_loss' , 0.1 ,
827
847
max_iter = 10 , tol = 1e-3 , verbose = False , warmstartT = True , random_state = 42 ,
828
- solver = 'PPA' , numItermax = 1 , log = False )
848
+ solver = 'PPA' , numItermax = 10 , log = False )
829
849
Xb , Cb = nx .to_numpy (Xb , Cb )
830
850
831
851
np .testing .assert_allclose (C , Cb , atol = 1e-06 )
@@ -1052,6 +1072,13 @@ def test_gromov_entropic_barycenter(nx):
1052
1072
1053
1073
C1b , C2b , p1b , p2b , pb = nx .from_numpy (C1 , C2 , p1 , p2 , p )
1054
1074
1075
+ with pytest .raises (ValueError ):
1076
+ loss_fun = 'weird_loss_fun'
1077
+ Cb = ot .gromov .entropic_gromov_barycenters (
1078
+ n_samples , [C1 , C2 ], None , p , [.5 , .5 ], loss_fun , 1e-3 ,
1079
+ max_iter = 10 , tol = 1e-3 , verbose = True , warmstartT = True , random_state = 42
1080
+ )
1081
+
1055
1082
Cb = ot .gromov .entropic_gromov_barycenters (
1056
1083
n_samples , [C1 , C2 ], None , p , [.5 , .5 ], 'square_loss' , 1e-3 ,
1057
1084
max_iter = 10 , tol = 1e-3 , verbose = True , warmstartT = True , random_state = 42
0 commit comments