@@ -1425,24 +1425,40 @@ def test_fgw_barycenter(nx):
1425
1425
init_C /= init_C .max ()
1426
1426
init_Cb = nx .from_numpy (init_C )
1427
1427
1428
- Xb , Cb = ot .gromov .fgw_barycenters (
1429
- n_samples , [ysb , ytb ], [C1b , C2b ], ps = [p1b , p2b ], lambdas = None ,
1430
- alpha = 0.5 , fixed_structure = True , init_C = init_Cb , fixed_features = False ,
1431
- p = None , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-3
1432
- )
1428
+ try : # to raise warning when `fixed_structure=True`and `init_C=None`
1429
+ Xb , Cb = ot .gromov .fgw_barycenters (
1430
+ n_samples , [ysb , ytb ], [C1b , C2b ], ps = [p1b , p2b ], lambdas = None ,
1431
+ alpha = 0.5 , fixed_structure = True , init_C = None , fixed_features = False ,
1432
+ p = None , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-3
1433
+ )
1434
+ except :
1435
+ Xb , Cb = ot .gromov .fgw_barycenters (
1436
+ n_samples , [ysb , ytb ], [C1b , C2b ], ps = [p1b , p2b ], lambdas = None ,
1437
+ alpha = 0.5 , fixed_structure = True , init_C = init_Cb , fixed_features = False ,
1438
+ p = None , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-3
1439
+ )
1433
1440
Xb , Cb = nx .to_numpy (Xb ), nx .to_numpy (Cb )
1434
1441
np .testing .assert_allclose (Cb .shape , (n_samples , n_samples ))
1435
1442
np .testing .assert_allclose (Xb .shape , (n_samples , ys .shape [1 ]))
1436
1443
1437
1444
init_X = rng .randn (n_samples , ys .shape [1 ])
1438
1445
init_Xb = nx .from_numpy (init_X )
1439
1446
1440
- Xb , Cb , logb = ot .gromov .fgw_barycenters (
1441
- n_samples , [ysb , ytb ], [C1b , C2b ], [p1b , p2b ], [.5 , .5 ], 0.5 ,
1442
- fixed_structure = False , fixed_features = True , init_X = init_Xb ,
1443
- p = pb , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-3 ,
1444
- warmstartT = True , log = True , random_state = 98765 , verbose = True
1445
- )
1447
+ try : # to raise warning when `fixed_features=True`and `init_X=None`
1448
+ Xb , Cb , logb = ot .gromov .fgw_barycenters (
1449
+ n_samples , [ysb , ytb ], [C1b , C2b ], [p1b , p2b ], [.5 , .5 ], 0.5 ,
1450
+ fixed_structure = False , fixed_features = True , init_X = None ,
1451
+ p = pb , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-3 ,
1452
+ warmstartT = True , log = True , random_state = 98765 , verbose = True
1453
+ )
1454
+ except :
1455
+ Xb , Cb , logb = ot .gromov .fgw_barycenters (
1456
+ n_samples , [ysb , ytb ], [C1b , C2b ], [p1b , p2b ], [.5 , .5 ], 0.5 ,
1457
+ fixed_structure = False , fixed_features = True , init_X = init_Xb ,
1458
+ p = pb , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-3 ,
1459
+ warmstartT = True , log = True , random_state = 98765 , verbose = True
1460
+ )
1461
+
1446
1462
X , C = nx .to_numpy (Xb ), nx .to_numpy (Cb )
1447
1463
np .testing .assert_allclose (C .shape , (n_samples , n_samples ))
1448
1464
np .testing .assert_allclose (X .shape , (n_samples , ys .shape [1 ]))
0 commit comments