@@ -344,6 +344,188 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None,
344
344
return W
345
345
346
346
347
def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, log=False):
    r"""Return the Bures-Wasserstein barycenter of Gaussian distributions.

    The function estimates the Wasserstein barycenter of the Gaussian
    distributions :math:`\left\{\mathcal{N}(\mu_i,\Sigma_i)\right\}_{i=1}^n`
    with a fixed-point algorithm
    :ref:`[1] <references-OT-mapping-linear-barycenter>`.

    The barycenter is still a Gaussian distribution
    :math:`\mathcal{N}(\mu_b,\Sigma_b)` where:

    .. math::
        \mu_b = \sum_{i=1}^n w_i \mu_i

    And the barycentric covariance is the solution of the following
    fixed-point equation:

    .. math::
        \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i\Sigma_b^{1/2}\right)^{1/2}


    Parameters
    ----------
    m : array-like (k,d)
        mean of k distributions
    C : array-like (k,d,d)
        covariance of k distributions
    weights : array-like (k), optional
        weights for each distribution (uniform weights if None)
    num_iter : int, optional
        maximum number of iterations of the fixed point algorithm
    eps : float, optional
        tolerance for the fixed point algorithm: iterations stop as soon as
        the norm of the difference between two successive covariance
        iterates is below ``eps``
    log : bool, optional
        record log if True


    Returns
    -------
    mb : (d,) array-like
        mean of the barycenter
    Cb : (d, d) array-like
        covariance of the barycenter
    log : dict
        log dictionary return only if log==True in parameters. It contains
        ``'num_iter'`` (zero-based index of the last iteration performed,
        -1 if none was run) and ``'final_diff'`` (last measured difference
        between successive iterates, None if no iteration was run).


    .. _references-OT-mapping-linear-barycenter:
    References
    ----------
    .. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space",
        SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
        2011.
    """
    nx = get_backend(*C, *m,)

    k = len(C)

    # The weights are needed for the mean as well, so default them first.
    if weights is None:
        weights = nx.ones(k, type_as=C[0]) / k

    # Weighted mean of the barycenter. The means are flattened to (k, d) so
    # that a (k, 1, d) stack of row vectors is reduced correctly too.
    mb = nx.sum(nx.reshape(m, (k, -1)) * weights[:, None], axis=0)

    # Init the covariance barycenter with the arithmetic mean
    Cb = nx.mean(C, axis=0)

    # Defined up front so that the log is well-formed even when num_iter == 0
    it, diff = -1, None

    for it in range(num_iter):
        # fixed point update: Cb <- sum_i w_i (Cb^{1/2} C_i Cb^{1/2})^{1/2}
        Cb12 = nx.sqrtm(Cb)
        sandwiched = Cb12 @ C @ Cb12
        roots = nx.stack([nx.sqrtm(sandwiched[i]) for i in range(k)], axis=0)
        Cnew = nx.sum(roots * weights[:, None, None], axis=0)

        # check convergence; keep the newest iterate in both cases
        diff = nx.norm(Cb - Cnew)
        Cb = Cnew
        if diff <= eps:
            break
    else:
        print("Did not converge.")

    if log:
        log_dict = {'num_iter': it, 'final_diff': diff}
        return mb, Cb, log_dict
    return mb, Cb
438
+
439
+
440
def empirical_bures_wasserstein_barycenter(
    X, reg=1e-6, weights=None, num_iter=1000, eps=1e-7,
    w=None, bias=True, log=False
):
    r"""Return the Bures-Wasserstein barycenter estimated from samples.

    A Gaussian distribution is fitted to each set of samples (empirical mean
    and regularized empirical covariance) and the barycenter of these
    Gaussian distributions
    :math:`\left\{\mathcal{N}(\mu_i,\Sigma_i)\right\}_{i=1}^n` is computed
    with :any:`bures_wasserstein_barycenter`
    :ref:`[1] <references-OT-mapping-linear-barycenter>`.

    The barycenter is still a Gaussian distribution
    :math:`\mathcal{N}(\mu_b,\Sigma_b)` where:

    .. math::
        \mu_b = \sum_{i=1}^n w_i \mu_i

    And the barycentric covariance is the solution of the following
    fixed-point equation:

    .. math::
        \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i\Sigma_b^{1/2}\right)^{1/2}


    Parameters
    ----------
    X : list of k array-like (n_i,d)
        samples in each distribution
    reg : float,optional
        regularization added to the diagonals of covariances (>0)
    weights : array-like (k,), optional
        weights for each distribution
    num_iter : int, optional
        number of iteration for the fixed point algorithm
    eps : float, optional
        tolerance for the fixed point algorithm
    w : list of k array-like (n_i,) or (n_i,1), optional
        weights for each sample in each distribution (uniform if None)
    bias: boolean, optional
        if True (default), estimate the mean of each distribution from its
        samples and center the samples before computing the covariance;
        otherwise the means are taken equal to zero
    log : bool, optional
        record log if True


    Returns
    -------
    mb : (d,) array-like
        mean of the barycenter, as returned by
        :any:`bures_wasserstein_barycenter`
    Cb : (d, d) array-like
        covariance of the barycenter
    log : dict
        log dictionary return only if log==True in parameters


    .. _references-OT-mapping-linear-barycenter:
    References
    ----------
    .. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space",
        SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
        2011.
    """
    X = list_to_array(*X)
    nx = get_backend(*X)

    d = [x.shape[1] for x in X]

    if w is None:
        w = [nx.ones((x.shape[0], 1), type_as=x) / x.shape[0] for x in X]
    else:
        # Column vectors so that the weights broadcast over the features,
        # whether they were given as (n,) or (n, 1).
        w = [nx.reshape(wi, (-1, 1)) for wi in w]

    if bias:
        # Weighted empirical means, consistent with the weighted covariances
        m = [nx.sum(x * wi, axis=0) / nx.sum(wi) for x, wi in zip(X, w)]
        X = [x - mi for x, mi in zip(X, m)]
    else:
        m = [nx.zeros((di,), type_as=x) for x, di in zip(X, d)]

    # Regularized weighted empirical covariance of each distribution
    C = [
        nx.dot((x * wi).T, x) / nx.sum(wi) + reg * nx.eye(di, type_as=x)
        for x, wi, di in zip(X, w, d)
    ]

    m = nx.stack(m, axis=0)  # (k, d)
    C = nx.stack(C, axis=0)  # (k, d, d)

    # The callee already returns (mb, Cb) or (mb, Cb, log) depending on log
    return bures_wasserstein_barycenter(
        m, C, weights=weights, num_iter=num_iter, eps=eps, log=log
    )
527
+
528
+
347
529
def gaussian_gromov_wasserstein_distance (Cov_s , Cov_t , log = False ):
348
530
r""" Return the Gaussian Gromov-Wasserstein value from [57].
349
531
0 commit comments