@@ -1425,39 +1425,38 @@ def test_fgw_barycenter(nx):
1425
1425
init_C /= init_C .max ()
1426
1426
init_Cb = nx .from_numpy (init_C )
1427
1427
1428
- try : # to raise warning when `fixed_structure=True`and `init_C=None`
1428
+ with pytest . raises ( ot . utils . UndefinedParameter ) : # to raise warning when `fixed_structure=True`and `init_C=None`
1429
1429
Xb , Cb = ot .gromov .fgw_barycenters (
1430
1430
n_samples , [ysb , ytb ], [C1b , C2b ], ps = [p1b , p2b ], lambdas = None ,
1431
1431
alpha = 0.5 , fixed_structure = True , init_C = None , fixed_features = False ,
1432
1432
p = None , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-3
1433
1433
)
1434
- except ot . utils . UndefinedParameter :
1435
- Xb , Cb = ot .gromov .fgw_barycenters (
1436
- n_samples , [ysb , ytb ], [C1b , C2b ], ps = [p1b , p2b ], lambdas = None ,
1437
- alpha = 0.5 , fixed_structure = True , init_C = init_Cb , fixed_features = False ,
1438
- p = None , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-3
1439
- )
1434
+
1435
+ Xb , Cb = ot .gromov .fgw_barycenters (
1436
+ n_samples , [ysb , ytb ], [C1b , C2b ], ps = [p1b , p2b ], lambdas = None ,
1437
+ alpha = 0.5 , fixed_structure = True , init_C = init_Cb , fixed_features = False ,
1438
+ p = None , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-3
1439
+ )
1440
1440
Xb , Cb = nx .to_numpy (Xb ), nx .to_numpy (Cb )
1441
1441
np .testing .assert_allclose (Cb .shape , (n_samples , n_samples ))
1442
1442
np .testing .assert_allclose (Xb .shape , (n_samples , ys .shape [1 ]))
1443
1443
1444
1444
init_X = rng .randn (n_samples , ys .shape [1 ])
1445
1445
init_Xb = nx .from_numpy (init_X )
1446
1446
1447
- try : # to raise warning when `fixed_features=True`and `init_X=None`
1447
+ with pytest . raises ( ot . utils . UndefinedParameter ) : # to raise warning when `fixed_features=True`and `init_X=None`
1448
1448
Xb , Cb , logb = ot .gromov .fgw_barycenters (
1449
1449
n_samples , [ysb , ytb ], [C1b , C2b ], [p1b , p2b ], [.5 , .5 ], 0.5 ,
1450
1450
fixed_structure = False , fixed_features = True , init_X = None ,
1451
1451
p = pb , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-3 ,
1452
1452
warmstartT = True , log = True , random_state = 98765 , verbose = True
1453
1453
)
1454
- except ot .utils .UndefinedParameter :
1455
- Xb , Cb , logb = ot .gromov .fgw_barycenters (
1456
- n_samples , [ysb , ytb ], [C1b , C2b ], [p1b , p2b ], [.5 , .5 ], 0.5 ,
1457
- fixed_structure = False , fixed_features = True , init_X = init_Xb ,
1458
- p = pb , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-3 ,
1459
- warmstartT = True , log = True , random_state = 98765 , verbose = True
1460
- )
1454
+ Xb , Cb , logb = ot .gromov .fgw_barycenters (
1455
+ n_samples , [ysb , ytb ], [C1b , C2b ], [p1b , p2b ], [.5 , .5 ], 0.5 ,
1456
+ fixed_structure = False , fixed_features = True , init_X = init_Xb ,
1457
+ p = pb , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-3 ,
1458
+ warmstartT = True , log = True , random_state = 98765 , verbose = True
1459
+ )
1461
1460
1462
1461
X , C = nx .to_numpy (Xb ), nx .to_numpy (Cb )
1463
1462
np .testing .assert_allclose (C .shape , (n_samples , n_samples ))
0 commit comments